Getting Started¶
Overview¶
This semester, all homeworks will be conducted through Google Colab notebooks. All code for the homework assignment will be written and run in this notebook. Running in Colab will automatically provide a GPU, but you may also run this notebook locally by following these instructions if you wish to use your own GPU.
You will save images from the notebooks and use them to fill out a given LaTeX template, which will be submitted to Gradescope along with your notebook code.
Using Colab¶
On the left-hand side, you can click the different icons to see a Table of Contents of the assignment, as well as local files accessible through the notebook.
Make sure to go to Runtime -> Change runtime type and select GPU as the hardware accelerator. This allows you to use a GPU. Run the cells below to get started on the assignment. Note that a session is open for a maximum of 12 hours, and using too much GPU compute may result in restricted access for a short period of time. Please start the homework early so you have ample time to work.
If you load this notebook by clicking "Open in Colab" from github, you will need to save it to your own Google Drive to keep your work.
General Tips¶
In each homework problem, you will implement autoregressive models and run them on various datasets. Often you will run a model on two datasets (dataset 1 and dataset 2). In these cases, the expected outputs for dataset 1 are already provided to help as a sanity check.
Feel free to print whatever output (e.g. debugging code, training code, etc) you want, as the graded submission will be the submitted pdf with images.
After you complete the assignment, download all of the images outputted in the results/ folder and upload them to the figure folder in the given LaTeX template.
There is a lot of freedom in this homework to design and write your own models. Hyperparameters are given as a guide to show what worked for us, but feel free to explore and use what you find is best!
Run the cells below to download and load up the starter code.
# !if [ -d deepul ]; then rm -Rf deepul; fi
# !git clone https://github.com/rll/deepul.git
# !unzip -qq deepul/homeworks/hw1/data/hw1_data.zip -d deepul/homeworks/hw1/data/
# !pip install ./deepul
import numpy as np
import copy
# import jax.numpy as np
from deepul.hw1_helper import (
# Q1
visualize_q1_data,
q1_sample_data_1,
q1_sample_data_2,
q1_save_results,
# Q2
q2a_save_results,
q2b_save_results,
visualize_q2a_data,
visualize_q2b_data,
# Q3
q3ab_save_results,
q3c_save_results,
# Q4
q4a_save_results,
q4b_save_results,
# Q5
visualize_q5_data,
q5a_save_results,
# Q6
visualize_q6_data,
q6a_save_results,
)
Question 1: 1D Data¶
In this question, we will train simple generative models on discrete 1D data.
Execute the cell below to visualize our datasets
visualize_q1_data(dset_type=1)
visualize_q1_data(dset_type=2)
Dataset 1
Dataset 2
Part (a) Fitting a Histogram¶
Let $\theta = (\theta_0, \dots, \theta_{d-1}) \in \mathbb{R}^{d}$ and define the model $p_\theta(x) = \frac{e^{\theta_x}}{\sum_{x'}e^{\theta_{x'}}}$
Fit $p_\theta$ with maximum likelihood via stochastic gradient descent on the training set, using $\theta$ initialized to zero. Use your favorite version of stochastic gradient descent, and optimize your hyperparameters on a validation set of your choice.
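A worked step that may help here (a quick derivation, not part of the assignment text): the gradient of the average negative log-likelihood has a simple closed form for this model. Writing $-\log p_\theta(x) = -\theta_x + \log \sum_{x'} e^{\theta_{x'}}$ and differentiating with respect to $\theta_j$ gives $-\mathbf{1}[x = j] + p_\theta(j)$, so averaging over a batch of size $N$ with empirical frequencies $\hat p_j = \frac{1}{N}\sum_n \mathbf{1}[x_n = j]$ yields

$$\frac{\partial}{\partial \theta_j}\left(-\frac{1}{N}\sum_{n=1}^N \log p_\theta(x_n)\right) = p_\theta(j) - \hat p_j.$$

This is why a softmax-histogram model can be fit with plain gradient steps on $p_\theta - \hat p$.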
You will provide these deliverables
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- Plot the model probabilities in a bar graph with $\{0,\dots,d-1\}$ on the x-axis and a real number in $[0,1]$ on the y-axis.
Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.
def softmax_probs(theta):
    """Compute all softmax probabilities for the given theta."""
    theta_shifted = theta - np.max(theta)  # stability trick
    exp_theta = np.exp(theta_shifted)
    return exp_theta / np.sum(exp_theta)
def q1_a(train_data, test_data, d, dset_id):
    """
    train_data: (n_train,) numpy array of integers in {0, ..., d-1}
    test_data: (n_test,) numpy array of integers in {0, ..., d-1}
    d: number of possible discrete values
    dset_id: which dataset is given (1 or 2)

    Returns train losses, test losses, and the learned distribution.
    """
    theta = np.zeros(d, dtype=float)
    # Empirical frequencies; the exact gradient of the NLL is probs - y
    y = np.bincount(train_data, minlength=d) / len(train_data)
    print("train_data shape: ", train_data.shape)
    print("d: ", d)
    print("dset_id", dset_id)

    # Hyperparameters (full-batch gradient steps on the analytic gradient)
    epochs = 200000
    learning_rate = 0.1
    eval_interval = 10  # evaluate test loss only periodically

    # Pre-allocate arrays
    train_losses = np.zeros(epochs)
    test_losses = np.zeros(epochs)

    for epoch in range(epochs):
        probs = softmax_probs(theta)
        nll = -np.mean(np.log(probs[train_data]))
        d_nll = probs - y
        theta -= learning_rate * d_nll
        train_losses[epoch] = nll

        if epoch % eval_interval == 0:
            test_losses[epoch] = -np.mean(np.log(probs[test_data]))
        else:  # carry the previous value forward for non-evaluated epochs
            test_losses[epoch] = test_losses[epoch - 1]

    distribution = softmax_probs(theta)
    return train_losses, test_losses, distribution
Results¶
Once you've implemented q1_a, execute the cells below to visualize and save your results
q1_save_results(1, 'a', q1_a)
train_data shape: (800,) d: 20 dset_id 1
Final Test Loss: 2.5434
q1_save_results(2, 'a', q1_a)
train_data shape: (8000,) d: 100 dset_id 2 Final Test Loss: 3.6897
Part (b) Fitting Discretized Mixture of Logistics¶
Let us model $p_\theta(x)$ as a discretized mixture of 4 logistics such that $p_\theta(x) = \sum_{i=1}^4 \pi_i[\sigma((x+0.5 - \mu_i)/s_i) - \sigma((x-0.5-\mu_i)/s_i)]$
For the edge case of when $x = 0$, we replace $x-0.5$ by $-\infty$, and for $x = 99$, we replace $x+0.5$ by $\infty$.
You may find the PixelCNN++ paper helpful for more information on discretized mixtures of logistics.
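One useful sanity check on this parameterization: because consecutive bins share a CDF boundary, the bin probabilities telescope, so with the edge cases handled the distribution sums exactly to 1. A minimal NumPy sketch (the mixture weights, means, and scales below are arbitrary illustrative values, not tuned parameters):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def discretized_mol_pmf(d, pis, mus, ss):
    """P(x) = sum_i pi_i [sigma((x+0.5-mu_i)/s_i) - sigma((x-0.5-mu_i)/s_i)],
    with the edge bins widened so x=0 uses CDF(-inf)=0 and x=d-1 uses CDF(inf)=1."""
    x = np.arange(d)[:, None]                       # (d, 1)
    cdf_hi = np.where(x == d - 1, 1.0, sigmoid((x + 0.5 - mus) / ss))
    cdf_lo = np.where(x == 0, 0.0, sigmoid((x - 0.5 - mus) / ss))
    return ((cdf_hi - cdf_lo) * pis).sum(axis=1)    # (d,)

pmf = discretized_mol_pmf(100,
                          np.array([0.25, 0.25, 0.25, 0.25]),
                          np.array([10., 30., 60., 90.]),
                          np.array([2., 5., 3., 4.]))
print(pmf.sum())  # ~1.0 thanks to the edge-case handling
```

If the edge-case replacement is dropped, the sum falls short of 1 by the probability mass the logistics place outside $[-0.5, d - 0.5]$.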
Provide the same set of corresponding deliverables as part (a)
Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
class MixtureOfLogistics(nn.Module):
    """Mixture of logistics distribution model for discrete data."""

    def __init__(self, d, n_mix=4):
        """
        Initialize the mixture of logistics model.

        Args:
            d: Number of possible discrete values (0 to d-1)
            n_mix: Number of mixture components
        """
        super().__init__()
        self.d = d
        self.n_mix = n_mix
        # Model parameters: mixture weights, component means, component scales
        self.logit_probs = nn.Parameter(torch.zeros(n_mix))  # shape (n_mix,)
        init_means = torch.linspace(0, d - 1, n_mix) + torch.randn(n_mix) * 0.1
        self.means = nn.Parameter(torch.clamp(init_means, 0, d - 1))  # shape (n_mix,)
        self.log_scales = nn.Parameter(torch.ones(n_mix) * 0.5)  # shape (n_mix,)

    def forward(self, x):
        """
        Compute the log probability of each value in x.

        Args:
            x: tensor of shape (batch_size,) containing integers in {0, ..., d-1}
        Returns:
            tensor of shape (batch_size,) containing log probabilities
        """
        # Mixture weights via softmax; positive scales via softplus
        probs = torch.softmax(self.logit_probs, dim=0).unsqueeze(0)  # (1, n_mix)
        scales = F.softplus(self.log_scales).unsqueeze(0)  # (1, n_mix)

        x = x.unsqueeze(1)  # (batch_size, 1)
        x_float = x.float()

        # CDF at x + 0.5 and x - 0.5 for each component
        first_term = torch.sigmoid((x_float + 0.5 - self.means) / scales)
        second_term = torch.sigmoid((x_float - 0.5 - self.means) / scales)

        # Edge cases: x = d - 1 uses CDF(inf) = 1, x = 0 uses CDF(-inf) = 0
        is_d_minus_one = (x == self.d - 1)
        is_zero = (x == 0)
        first_term = torch.where(is_d_minus_one, torch.ones_like(first_term), first_term)
        second_term = torch.where(is_zero, torch.zeros_like(second_term), second_term)

        # CDF differences, weighted by mixture weights and summed
        component_probs = first_term - second_term  # (batch_size, n_mix)
        probabilities = (probs * component_probs).sum(dim=1)  # (batch_size,)
        return torch.log(probabilities + 1e-10)

    def get_distribution(self):
        """
        Returns the probability distribution over all possible values.

        Returns:
            numpy array of shape (d,) containing probabilities
        """
        device = next(self.parameters()).device
        x = torch.arange(self.d, device=device)
        with torch.no_grad():
            log_probs = self.forward(x)
            probs = torch.exp(log_probs)
            return (probs / torch.sum(probs)).cpu().numpy()
def q1_b(train_data, test_data, d, dset_id):
    """
    Train a mixture of logistics model on discrete data.
    """
    # Set random seed and device
    torch.manual_seed(42)
    np.random.seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Convert data to tensors
    train_tensor = torch.tensor(train_data, dtype=torch.long, device=device)
    test_tensor = torch.tensor(test_data, dtype=torch.long, device=device)

    # Create datasets and dataloaders
    train_dataset = TensorDataset(train_tensor)
    test_dataset = TensorDataset(test_tensor)

    # Set hyperparameters based on dataset ID
    hyperparams = {
        1: {"batch_size": 800, "lr": 0.005, "num_epochs": 10000},
        2: {"batch_size": 800, "lr": 0.001, "num_epochs": 10000},
    }[dset_id]

    train_loader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

    train_losses = np.zeros(hyperparams["num_epochs"])
    test_losses = np.zeros(hyperparams["num_epochs"])

    # Initialize model and optimizer
    model = MixtureOfLogistics(d, n_mix=4).to(device)
    optimizer = optim.Adam(model.parameters(), lr=hyperparams["lr"])

    for epoch in range(hyperparams["num_epochs"]):
        # Train on batches and record the average loss
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            x = batch[0].to(device)
            log_probs = model(x)
            nll = -log_probs.mean()
            nll.backward()
            optimizer.step()
            train_loss += nll.item()
        train_losses[epoch] = train_loss / len(train_loader)

        # Evaluate on the test set
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch in test_loader:
                x = batch[0].to(device)
                log_probs = model(x)
                test_loss += (-log_probs.mean()).item()
        test_losses[epoch] = test_loss / len(test_loader)

    # Final model probabilities
    model.eval()
    model_probs = model.get_distribution()
    print("len(model_probs)", len(model_probs))
    return train_losses, test_losses, model_probs
Results¶
Once you've implemented q1_b, execute the cells below to visualize and save your results
q1_save_results(1, 'b', q1_b)
len(model_probs) 20 Final Test Loss: 2.5499
q1_save_results(2, 'b', q1_b)
len(model_probs) 100 Final Test Loss: 4.0082
Question 2 PixelCNNs¶
Now, you will train more powerful PixelCNN models on the shapes dataset and MNIST. In addition, we will extend to modeling colored datasets.
Run the cells below to visualize the two binary datasets
visualize_q2a_data(1)
visualize_q2a_data(2)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data samples shape: (100, 20, 20, 1)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data samples shape: (100, 28, 28, 1)
Part (a) PixelCNN on Shapes and MNIST¶
In this part, implement a simple PixelCNN architecture to model binary shapes and MNIST images.
We recommend the following network design:
- A $7 \times 7$ masked type A convolution
- $5$ $7 \times 7$ masked type B convolutions
- $2$ $1 \times 1$ masked type B convolutions
- Appropriate ReLU nonlinearities in-between
- 64 convolutional filters
And the following hyperparameters:
- Batch size 128
- Learning rate $10^{-3}$
- 10 epochs
- Adam Optimizer (this applies to all PixelCNN models trained in future parts)
Your model should output logits, after which you could apply a sigmoid over 1 logit, or a softmax over two logits (either is fine). It may also help to scale your input to $[-1, 1]$ before running it through the network.
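The two output choices mentioned above are equivalent: a sigmoid over a single logit $z$ gives the same Bernoulli probability as a softmax over the two logits $(0, z)$, since $\sigma(z) = e^z/(1+e^z)$. A quick NumPy check with arbitrary illustrative logit values:

```python
import numpy as np

z = np.array([-3.0, -0.5, 0.0, 2.0])

# Sigmoid over one logit
p_sigmoid = 1.0 / (1.0 + np.exp(-z))

# Softmax over two logits (0, z): p(x=1) = e^z / (1 + e^z)
logits = np.stack([np.zeros_like(z), z], axis=1)
e = np.exp(logits - logits.max(axis=1, keepdims=True))  # stability shift
p_softmax = (e / e.sum(axis=1, keepdims=True))[:, 1]

assert np.allclose(p_sigmoid, p_softmax)
```

So either head yields the same model; the choice only changes the number of output channels.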
Training on the shapes dataset should be quick, and MNIST should take around 10 minutes.
Check out the paper for more details: https://arxiv.org/abs/1601.06759
You will provide these deliverables
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- 100 samples from the final trained model
Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm
class MaskedConv2d(nn.Conv2d):
    """
    Masked convolution layer for PixelCNN.
    Masks can be of type 'A' (excludes the center pixel) or 'B' (includes it).
    """
    def __init__(self, in_channels, out_channels, kernel_size, mask_type='A', padding='same', **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self.register_buffer('mask', torch.ones_like(self.weight))
        self.mask_type = mask_type

        # Build the mask (a square kernel when an int is given)
        if isinstance(kernel_size, int):
            h, w = kernel_size, kernel_size
        else:
            h, w = kernel_size
        center_h, center_w = h // 2, w // 2
        for i in range(h):
            for j in range(w):
                # Mask out future pixels (below, and to the right in the center row)
                if (i > center_h) or (i == center_h and j > center_w):
                    self.mask[:, :, i, j] = 0
                # For mask type A, also mask out the center pixel
                if mask_type == 'A' and i == center_h and j == center_w:
                    self.mask[:, :, i, j] = 0

    def forward(self, x):
        # Zero out masked weights in place before convolving; any updates to
        # masked entries are re-zeroed on the next forward pass.
        self.weight.data *= self.mask
        return super().forward(x)
class PixelCNN(nn.Module):
    """
    PixelCNN model for binary image generation.

    As recommended in the assignment:
    - One 7x7 masked type A convolution
    - Five 7x7 masked type B convolutions
    - Two 1x1 masked type B convolutions
    - ReLU nonlinearities in-between
    - 64 convolutional filters
    """
    def __init__(self, in_channels=1, hidden_dim=64):
        super().__init__()
        # Initial masked convolutional layer of type A
        self.conv_a = MaskedConv2d(in_channels, hidden_dim, kernel_size=7, mask_type='A', padding='same')
        # Stack of masked convolutional layers of type B
        self.conv_b_stack = nn.ModuleList([
            MaskedConv2d(hidden_dim, hidden_dim, kernel_size=7, mask_type='B', padding='same')
            for _ in range(5)
        ])
        # Final 1x1 convolutions
        self.conv_1x1_stack = nn.ModuleList([
            MaskedConv2d(hidden_dim, hidden_dim, kernel_size=1, mask_type='B', padding='same')
            for _ in range(2)
        ])
        # Output layer: 1 channel of logits for binary output (sigmoid applied later)
        self.output_conv = MaskedConv2d(hidden_dim, 1, kernel_size=1, mask_type='B', padding='same')

    def forward(self, x):
        # Apply first mask A convolution
        x = F.relu(self.conv_a(x))
        # Apply mask B convolutions with ReLU
        for conv_b in self.conv_b_stack:
            x = F.relu(conv_b(x))
        # Apply 1x1 convolutions with ReLU
        for conv_1x1 in self.conv_1x1_stack:
            x = F.relu(conv_1x1(x))
        # Final output layer (returns logits)
        return self.output_conv(x)
def sample_from_model(model, image_shape, device, num_samples=100):
    """
    Sample images from the trained model using ancestral sampling.
    """
    model.eval()
    samples = torch.zeros((num_samples, 1, image_shape[0], image_shape[1]), device=device)
    with torch.no_grad():
        # Generate each pixel sequentially in raster-scan order
        for i in range(image_shape[0]):
            for j in range(image_shape[1]):
                # Get the model's prediction for this pixel
                logits = model(samples)[:, :, i, j]
                # Convert logits to probabilities and sample from Bernoulli
                probs = torch.sigmoid(logits)
                samples[:, :, i, j] = torch.bernoulli(probs)
    return samples.cpu().numpy().transpose(0, 2, 3, 1)
def binary_cross_entropy_loss(logits, targets):
    """
    Compute binary cross entropy loss from logits.
    """
    return F.binary_cross_entropy_with_logits(logits, targets)

def negative_log_likelihood(logits, targets):
    """
    Compute the total negative log likelihood per image, in nats.
    (binary_cross_entropy_with_logits already averages in nats per
    dimension, so multiplying by the number of dimensions gives the
    per-image total; no bits-to-nats conversion is needed.)
    """
    n_dims = targets.size(1) * targets.size(2) * targets.size(3)
    bce = binary_cross_entropy_loss(logits, targets)
    return bce * n_dims
def q2_a(train_data, test_data, image_shape, dset_id):
    """
    train_data: A (n_train, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    test_data: A (n_test, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    image_shape: (H, W), height and width of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
        used to set different hyperparameters for different datasets

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, W, 1) of samples with values in {0, 1}
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 10

    # Convert data to NCHW float tensors (values stay in {0, 1})
    train_data = torch.from_numpy(train_data).float().permute(0, 3, 1, 2).to(device)
    test_data = torch.from_numpy(test_data).float().permute(0, 3, 1, 2).to(device)

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Initialize model and optimizer
    model = PixelCNN(in_channels=1, hidden_dim=64).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Lists to store metrics
    train_losses = []
    test_losses = []

    # Test loss at initialization
    model.eval()
    total_test_loss = 0
    with torch.no_grad():
        for data in test_loader:
            outputs = model(data)
            total_test_loss += negative_log_likelihood(outputs, data).item()
    test_losses.append(total_test_loss / len(test_loader))

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for batch_idx, data in enumerate(progress_bar):
            optimizer.zero_grad()
            outputs = model(data)
            loss = binary_cross_entropy_loss(outputs, data)
            loss.backward()
            optimizer.step()

            # Track the per-image NLL
            nll = negative_log_likelihood(outputs, data)
            train_losses.append(nll.item())
            progress_bar.set_postfix({'Loss': nll.item()})

        # Evaluate on test set after each epoch
        model.eval()
        total_test_loss = 0
        with torch.no_grad():
            for data in test_loader:
                outputs = model(data)
                total_test_loss += negative_log_likelihood(outputs, data).item()
        epoch_test_loss = total_test_loss / len(test_loader)
        test_losses.append(epoch_test_loss)
        print(f'Epoch {epoch+1}: Test Loss: {epoch_test_loss:.6f}')

    # Generate samples (bernoulli draws are already in {0, 1}; cast to uint8)
    samples = sample_from_model(model, image_shape, device, num_samples=100)
    samples = (samples > 0.5).astype(np.uint8)
    return np.array(train_losses), np.array(test_losses), samples
Results¶
Once you've implemented q2_a, execute the cells below to visualize and save your results
q2a_save_results(1, q2_a)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/10: 100%|██████████| 82/82 [00:01<00:00, 69.27it/s, Loss=62.2]
Epoch 1: Test Loss: 63.498842
Epoch 2/10: 100%|██████████| 82/82 [00:01<00:00, 76.21it/s, Loss=53.7]
Epoch 2: Test Loss: 56.255177
Epoch 3/10: 100%|██████████| 82/82 [00:01<00:00, 76.00it/s, Loss=43.2]
Epoch 3: Test Loss: 44.143267
Epoch 4/10: 100%|██████████| 82/82 [00:01<00:00, 75.82it/s, Loss=35.9]
Epoch 4: Test Loss: 35.370099
Epoch 5/10: 100%|██████████| 82/82 [00:01<00:00, 75.84it/s, Loss=31.9]
Epoch 5: Test Loss: 30.587291
Epoch 6/10: 100%|██████████| 82/82 [00:01<00:00, 75.87it/s, Loss=26.7]
Epoch 6: Test Loss: 27.376665
Epoch 7/10: 100%|██████████| 82/82 [00:01<00:00, 75.84it/s, Loss=25.1]
Epoch 7: Test Loss: 26.798818
Epoch 8/10: 100%|██████████| 82/82 [00:01<00:00, 76.25it/s, Loss=24.3]
Epoch 8: Test Loss: 23.934251
Epoch 9/10: 100%|██████████| 82/82 [00:01<00:00, 76.02it/s, Loss=21.8]
Epoch 9: Test Loss: 22.455027
Epoch 10/10: 100%|██████████| 82/82 [00:01<00:00, 76.16it/s, Loss=20.7]
Epoch 10: Test Loss: 20.706674 Final Test Loss: 20.7067
samples shape: (100, 20, 20, 1)
q2a_save_results(2, q2_a)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/10: 100%|██████████| 469/469 [00:10<00:00, 43.04it/s, Loss=68.7]
Epoch 1: Test Loss: 69.982057
Epoch 2/10: 100%|██████████| 469/469 [00:10<00:00, 42.91it/s, Loss=68.3]
Epoch 2: Test Loss: 66.214365
Epoch 3/10: 100%|██████████| 469/469 [00:11<00:00, 42.53it/s, Loss=64.3]
Epoch 3: Test Loss: 65.107751
Epoch 4/10: 100%|██████████| 469/469 [00:11<00:00, 42.28it/s, Loss=62]
Epoch 4: Test Loss: 63.785228
Epoch 5/10: 100%|██████████| 469/469 [00:11<00:00, 42.00it/s, Loss=65.2]
Epoch 5: Test Loss: 63.833508
Epoch 6/10: 100%|██████████| 469/469 [00:11<00:00, 42.03it/s, Loss=62.2]
Epoch 6: Test Loss: 62.760284
Epoch 7/10: 100%|██████████| 469/469 [00:11<00:00, 42.03it/s, Loss=64.5]
Epoch 7: Test Loss: 62.533142
Epoch 8/10: 100%|██████████| 469/469 [00:11<00:00, 42.61it/s, Loss=64.3]
Epoch 8: Test Loss: 62.928753
Epoch 9/10: 100%|██████████| 469/469 [00:10<00:00, 42.89it/s, Loss=61.7]
Epoch 9: Test Loss: 61.973298
Epoch 10/10: 100%|██████████| 469/469 [00:11<00:00, 42.39it/s, Loss=61]
Epoch 10: Test Loss: 61.620466 Final Test Loss: 61.6205
samples shape: (100, 28, 28, 1)
Part (b) PixelCNN on Colored Shapes and MNIST: Independent Color Channels¶
For the next part, we'll work with color images (shapes and MNIST). Run the cell below to visualize the dataset.
visualize_q2b_data(1)
visualize_q2b_data(2)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data samples shape: (100, 20, 20, 3)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data samples shape: (100, 28, 28, 3)
Now, implement a PixelCNN to support RGB color channels (or augment your existing implementation). First, implement a PixelCNN that assumes color channels as independent. More formally, we model the following parameterized distribution:
$$p_\theta(x) = \prod_{i=1}^{HW}\prod_{c=1}^C p_\theta(x_i^c | x_{<i})$$
Here are some tips that you may find useful for designing and training these models:
- You will need a 4-way softmax for every prediction, as opposed to a 256-way softmax in the PixelCNN paper, since the dataset is quantized to two bits per color channel
- You can set the number of filters for each convolution to 120, and use the ReLU nonlinearity throughout.
- Use a stack of 8 residual blocks with the architecture from Figure 5 of the PixelCNN paper, but with 7 x 7 masked convolutions in the middle instead of 3 x 3 masked convolutions
- Consider using layer normalization to improve performance. However, be careful to maintain the autoregressive property.
- With a learning rate of $10^{-3}$ and a batch size of 128, it should take a few minutes to run on the shapes dataset, and about 50-60 minutes on MNIST.
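One sanity check that follows from the 2-bit quantization: at initialization a roughly uniform model should sit near $\ln 4 \approx 1.386$ nats/dim (i.e. 2 bits/dim), so training curves that start far from this value usually indicate a scaling or loss-accounting bug. The arithmetic:

```python
import math

# Uniform over 4 values -> -log(1/4) = log 4 nats per dimension
uniform_nats = math.log(4)
print(round(uniform_nats, 3))  # 1.386

# Convert nats/dim to bits/dim by dividing by ln 2
uniform_bits = uniform_nats / math.log(2)
print(round(uniform_bits, 1))  # 2.0
```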
You will provide these deliverables
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- 100 samples from the final trained model
Fill out the function below and return the necessary arguments. Feel free to create more cells if need be.
import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
quiet = False
def train(model, train_loader, optimizer, epoch, grad_clip=None):
    model.train()
    train_losses = []
    for x in train_loader:
        x = x.cuda().contiguous()
        loss = model.loss(x)
        optimizer.zero_grad()
        loss.backward()
        if grad_clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        train_losses.append(loss.item())
    return train_losses

def eval_loss(model, data_loader):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x in data_loader:
            x = x.cuda().contiguous()
            loss = model.loss(x)
            total_loss += loss * x.shape[0]
        avg_loss = total_loss / len(data_loader.dataset)
    return avg_loss.item()

def train_epochs(model, train_loader, test_loader, train_args):
    epochs, lr = train_args['epochs'], train_args['lr']
    grad_clip = train_args.get('grad_clip', None)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_losses = []
    test_losses = [eval_loss(model, test_loader)]
    for epoch in range(epochs):
        model.train()
        train_losses.extend(train(model, train_loader, optimizer, epoch, grad_clip))
        test_loss = eval_loss(model, test_loader)
        test_losses.append(test_loss)
        if not quiet:
            print(f'Epoch {epoch}, Test loss {test_loss:.4f}')
    return train_losses, test_losses
class Histogram(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d
        self.logits = nn.Parameter(torch.zeros(d), requires_grad=True)

    def loss(self, x):
        logits = self.logits.unsqueeze(0).repeat(x.shape[0], 1)  # batch_size x d
        return F.cross_entropy(logits, x.long())

    def get_distribution(self):
        distribution = F.softmax(self.logits, dim=0)
        return distribution.detach().cpu().numpy()
class MaskedConv2d(nn.Conv2d):
    """
    Implementation of masked convolution for PixelCNN.
    """
    def __init__(self, mask_type, *args, **kwargs):
        assert mask_type == 'A' or mask_type == 'B'
        super().__init__(*args, **kwargs)
        self.register_buffer('mask', torch.zeros_like(self.weight))
        self.create_mask(mask_type)

    def forward(self, input):
        # Convolve with the masked weights
        out = F.conv2d(input, self.weight * self.mask, self.bias, self.stride,
                       self.padding, self.dilation, self.groups)
        return out

    def create_mask(self, mask_type):
        # Get kernel size (assuming a square kernel)
        k = self.kernel_size[0]
        # Allow all rows above the center
        self.mask[:, :, :k // 2] = 1
        # Allow positions to the left of center in the center row
        self.mask[:, :, k // 2, :k // 2] = 1
        # For type B masks, also allow the center pixel
        if mask_type == 'B':
            self.mask[:, :, k // 2, k // 2] = 1

class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        # Permute to [batch, height, width, channels] so normalized_shape
        # matches the last dimension, normalize, then permute back.
        x = x.permute(0, 2, 3, 1).contiguous()
        x = super().forward(x)
        return x.permute(0, 3, 1, 2).contiguous()
class ResidualBlock(nn.Module):
    def __init__(self, n_filters, image_shape=None):
        super().__init__()
        # Layer normalization over the channel dimension only; normalizing
        # over the spatial dimensions would pool statistics across pixel
        # positions and break the autoregressive property. (Note: passing
        # extra positional args to nn.LayerNorm would silently set eps, so
        # only normalized_shape is given here.)
        self.layer_norm1 = LayerNorm(n_filters)
        self.layer_norm2 = LayerNorm(n_filters)
        self.layer_norm3 = LayerNorm(n_filters)
        # Main path
        self.conv1 = nn.Conv2d(n_filters, n_filters, kernel_size=1)
        self.conv2 = MaskedConv2d('B', n_filters, n_filters, kernel_size=7, padding=3)
        self.conv3 = nn.Conv2d(n_filters, n_filters, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Store input for the skip connection
        identity = x
        # Main path
        out = self.layer_norm1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.layer_norm2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.layer_norm3(out)
        out = self.relu(out)
        out = self.conv3(out)
        # Skip connection
        out += identity
        return out
def compute_loss(logits, targets):
    """
    Compute cross-entropy loss for the PixelCNN model.

    Args:
        logits: Tensor of shape [batch_size, C, 4, H, W] - model predictions
        targets: Tensor of shape [batch_size, C, H, W] - ground truth values
    Returns:
        Average cross-entropy loss
    """
    logits_reshaped = logits.permute(0, 1, 3, 4, 2).reshape(-1, 4)  # [batch_size * C * H * W, 4]
    targets_reshaped = targets.reshape(-1).long()
    return F.cross_entropy(logits_reshaped, targets_reshaped)

def evaluate_model(model, data_loader, device):
    """
    Evaluate the model on a dataset.
    """
    model.eval()
    total_loss = 0.0
    total_batches = 0
    with torch.no_grad():
        for (data,) in data_loader:
            data = data.to(device)
            total_batches += 1
            logits = model(data)
            loss = compute_loss(logits, data)
            total_loss += loss.item()
    return total_loss / total_batches
def generate_samples(model, num_samples, image_shape, device):
    """
    Generate samples from the model using ancestral sampling.
    Assumes the model expects float inputs representing integer values {0, 1, 2, 3}.
    """
    model.eval()
    H, W, C = image_shape
    # Float tensor, even though the values represent discrete levels
    samples = torch.zeros(num_samples, C, H, W, dtype=torch.float, device=device)
    temperature = 0.6  # < 1 sharpens the per-pixel distribution
    with torch.no_grad():
        for h in range(H):
            for w in range(W):
                for c in range(C):
                    logits = model(samples)
                    pixel_logits = logits[:, c, :, h, w]  # [num_samples, 4]
                    probs = F.softmax(pixel_logits / temperature, dim=1)
                    pixel_samples = torch.multinomial(probs, 1).squeeze(-1)  # [num_samples]
                    samples[:, c, h, w] = pixel_samples.float()
    samples_np = samples.cpu().numpy().transpose(0, 2, 3, 1)
    return samples_np.astype(np.uint8)
class PixelCNN(nn.Module):
    """
    PixelCNN model with masked convolutions for 2-bit RGB images.
    """
    def __init__(self, image_shape, dset_id):
        super().__init__()
        self.image_shape = image_shape
        self.dset_id = dset_id
        self.n_colors = 4
        # Number of input channels (3 for RGB)
        in_channels = 3
        # Number of filters as specified in the assignment
        n_filters = 120
        # First layer: 7x7 masked type A convolution
        self.conv_A = MaskedConv2d('A', in_channels, n_filters, 7, padding=3, bias=True)
        # Stack of 8 residual blocks with 7x7 masked type B convolutions
        self.residual_layers = nn.ModuleList([
            ResidualBlock(n_filters, image_shape)
            for _ in range(8)
        ])
        # 2 layers of 1x1 masked type B convolutions
        self.conv_B_1x1_layers = nn.ModuleList([
            MaskedConv2d('B', n_filters, n_filters, 1, padding=0, bias=True)
            for _ in range(2)
        ])
        # Output 4 logits for each of the 3 color channels
        self.output_conv = nn.Conv2d(n_filters, 4 * 3, 1, padding=0, bias=True)
        # ReLU activation
        self.relu = nn.ReLU()

    def forward(self, x):
        # Scale inputs from {0, ..., 3} to [-1, 1], then apply the type A convolution
        x = (x.float() / (self.n_colors - 1) - 0.5) / 0.5
        x = self.relu(self.conv_A(x))
        # Residual blocks
        for layer in self.residual_layers:
            x = layer(x)
        # 1x1 mask B convolutions with ReLU activations
        for layer in self.conv_B_1x1_layers:
            x = self.relu(layer(x))
        # Final convolution, reshaped to (batch, channel, 4 logits, H, W)
        x = self.output_conv(x)
        batch_size, _, height, width = x.shape
        return x.view(batch_size, 3, 4, height, width)
def q2_b(train_data, test_data, image_shape, dset_id):
    """
    Trains a PixelCNN model for RGB images with 4 possible values per channel.

    Args:
        train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
        test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
        image_shape: (H, W, C), height, width, and # of channels of the image
        dset_id: An identifying number of which dataset is given (1 or 2).
            Used to set different hyperparameters for different datasets

    Returns:
        - train_losses: A (# of training iterations,) numpy array of per-batch training losses
        - test_losses: A (# of epochs + 1,) numpy array of test losses after each epoch (including initialization)
        - samples: A (100, H, W, C) numpy array of generated samples with values in {0, 1, 2, 3}
    """
    # Hyperparameters
    batch_size = 128
    learning_rate = 0.001 * np.sqrt(batch_size / 128)
    num_epochs = 20
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data preparation
    train_data_tensor = torch.FloatTensor(train_data).permute(0, 3, 1, 2)
    test_data_tensor = torch.FloatTensor(test_data).permute(0, 3, 1, 2)
    train_dataset = torch.utils.data.TensorDataset(train_data_tensor)
    test_dataset = torch.utils.data.TensorDataset(test_data_tensor)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=(device.type == 'cuda')
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=(device.type == 'cuda')
    )

    # Model initialization
    model = PixelCNN(image_shape, dset_id).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',     # Reduce LR when the monitored metric stops decreasing
        factor=0.5,     # Multiply the learning rate by this factor when reducing
        patience=2,     # Number of epochs with no improvement after which LR will be reduced
        # verbose=True, # Print a message when the LR is reduced
        min_lr=1e-6     # Lower bound on the learning rate
    )
    # Initialize the gradient scaler for mixed-precision training
    scaler = torch.cuda.amp.GradScaler()

    # Early stopping parameters
    best_loss = float('inf')
    best_model_state = None
    patience = 5
    patience_counter = 0

    # Loss tracking
    train_losses = []
    test_losses = []

    # Initial evaluation
    init_test_loss = evaluate_model(model, test_loader, device)
    test_losses.append(init_test_loss)
    print(f"Initial test loss: {init_test_loss:.6f}")

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_train_losses = []
        for batch_idx, (data,) in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            # Use autocast for mixed-precision training
            with torch.cuda.amp.autocast():
                logits = model(data)
                loss = compute_loss(logits, data)
            # Scale the loss and backpropagate
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # Reduced from 1.0 for stability
            # Update weights through the scaler
            scaler.step(optimizer)
            scaler.update()
            loss_value = loss.item()
            train_losses.append(loss_value)
            epoch_train_losses.append(loss_value)
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss_value:.6f}')
        avg_train_loss = sum(epoch_train_losses) / len(epoch_train_losses)
        print(f'Epoch {epoch+1} average training loss: {avg_train_loss:.6f}')

        # Evaluate the model
        test_loss = evaluate_model(model, test_loader, device)
        test_losses.append(test_loss)
        print(f'Epoch {epoch+1} test loss: {test_loss:.6f}')

        # Update the learning rate based on the test loss
        scheduler.step(test_loss)
        # Print the current learning rate (the correct way for ReduceLROnPlateau)
        print(f'Current learning rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Early stopping check
        if test_loss < best_loss:
            best_loss = test_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"New best model with test loss: {best_loss:.6f}")
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} epochs")
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Load the best model for sampling
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with test loss: {best_loss:.6f}")

    # Generate samples
    samples = generate_samples(model, 100, image_shape, device)

    return np.array(train_losses), np.array(test_losses), samples
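`generate_samples` is defined elsewhere in the notebook. As a reference for what it must do, a minimal sketch of autoregressive sampling is below: every pixel and channel is drawn in raster order, re-running the model after each draw so later pixels condition on earlier ones. `generate_samples_sketch` is a hypothetical name, and the logit layout (N, 3, 4, H, W) is taken from the model's `forward` above:

```python
import torch

def generate_samples_sketch(model, n_samples, image_shape, device):
    """Sample images pixel by pixel, channel by channel, in raster order."""
    H, W, C = image_shape
    model.eval()
    samples = torch.zeros(n_samples, C, H, W, device=device)
    with torch.no_grad():
        for h in range(H):
            for w in range(W):
                for c in range(C):
                    logits = model(samples)  # (N, C, 4, H, W)
                    # Categorical distribution over the 4 values at (c, h, w)
                    probs = torch.softmax(logits[:, c, :, h, w], dim=-1)
                    samples[:, c, h, w] = torch.multinomial(probs, 1).squeeze(-1)
    # Return (N, H, W, C) to match the expected output format
    return samples.permute(0, 2, 3, 1).cpu().numpy()
```

Note this requires H * W * C forward passes, which is why sampling 100 images can take noticeably longer than a training epoch.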
def evaluate_model(model, data_loader, device):
    """
    Evaluate the model on a dataset, returning the mean per-batch loss.
    """
    model.eval()
    total_loss = 0.0
    total_batches = 0
    with torch.no_grad():
        for (data,) in data_loader:
            data = data.to(device)
            total_batches += 1
            with torch.cuda.amp.autocast():  # Use autocast for evaluation too
                logits = model(data)
                loss = compute_loss(logits, data)
            total_loss += loss.item()
    return total_loss / total_batches
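`compute_loss` is also defined elsewhere, but its behavior can be inferred from the outputs below: the initial test loss is about 1.389 ≈ ln 4, i.e. the mean cross-entropy in nats per (pixel, channel) under uniform logits. A sketch consistent with that, using a hypothetical name `compute_loss_sketch` and the (N, 3, 4, H, W) logit layout from the model above:

```python
import torch
import torch.nn.functional as F

def compute_loss_sketch(logits, targets):
    """Mean cross-entropy in nats per (pixel, channel).

    logits:  (N, 3, 4, H, W) unnormalized scores from the model
    targets: (N, 3, H, W) tensor with values in {0, 1, 2, 3}
    """
    # F.cross_entropy expects the class dimension in position 1, so fold
    # the color channels into the batch: (N*3, 4, H, W) vs. (N*3, H, W).
    N, C, K, H, W = logits.shape
    return F.cross_entropy(
        logits.reshape(N * C, K, H, W),
        targets.long().reshape(N * C, H, W),
    )
```

A quick sanity check: all-zero logits give a loss of ln 4 ≈ 1.386 regardless of the targets, matching the "Initial test loss" printed for the untrained model.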
Results¶
Once you've implemented q2_b, execute the cells below to visualize and save your results.
q2b_save_results(1, 'b', q2_b)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Using device: cuda
/tmp/ipykernel_2361820/2857294626.py:61: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = torch.cuda.amp.GradScaler()
/tmp/ipykernel_2361820/2857294626.py:162: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(): # Use autocast for evaluation too
Initial test loss: 1.388850
/tmp/ipykernel_2361820/2857294626.py:88: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast():
Epoch  1: avg train loss 0.874104, test loss 0.472679
Epoch  2: avg train loss 0.361904, test loss 0.286814
Epoch  3: avg train loss 0.226909, test loss 0.184995
Epoch  4: avg train loss 0.154246, test loss 0.131514
Epoch  5: avg train loss 0.119808, test loss 0.112343
Epoch  6: avg train loss 0.107490, test loss 0.103943
Epoch  7: avg train loss 0.100878, test loss 0.098755
Epoch  8: avg train loss 0.097322, test loss 0.096321
Epoch  9: avg train loss 0.094821, test loss 0.094621
Epoch 10: avg train loss 0.093107, test loss 0.091463
Epoch 11: avg train loss 0.090473, test loss 0.089798
Epoch 12: avg train loss 0.088904, test loss 0.087783
Epoch 13: avg train loss 0.087305, test loss 0.087037
Epoch 14: avg train loss 0.086166, test loss 0.086303
Epoch 15: avg train loss 0.086581, test loss 0.088489 (no improvement for 1 epoch)
Epoch 16: avg train loss 0.085030, test loss 0.086015
Epoch 17: avg train loss 0.085812, test loss 0.086086 (no improvement for 1 epoch)
Epoch 18: avg train loss 0.085906, test loss 0.090479 (no improvement for 2 epochs)
Epoch 19: avg train loss 0.085502, test loss 0.085137
Epoch 20: avg train loss 0.085175, test loss 0.084870
Loaded best model with test loss: 0.084870
Final Test Loss: 0.0849
samples shape: (100, 20, 20, 3)
q2b_save_results(2, 'b', q2_b)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Using device: cuda
/tmp/ipykernel_2361820/2857294626.py:61: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
scaler = torch.cuda.amp.GradScaler()
/tmp/ipykernel_2361820/2857294626.py:162: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(): # Use autocast for evaluation too
Initial test loss: 1.365160
/tmp/ipykernel_2361820/2857294626.py:88: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast():
Epoch 1: avg train loss 0.374674, test loss 0.251673
Epoch 2: avg train loss 0.247342, test loss 0.241083
Epoch 3: avg train loss 0.236651, test loss 0.230761
Epoch 4: avg train loss 0.228498, test loss 0.224772
Epoch 5: avg train loss 0.223410, test loss 0.221102
Epoch 6: avg train loss 0.220853, test loss 0.220251
Epoch 7: avg train loss 0.219044, test loss 0.216958
Epoch 8: avg train loss 0.217465, test loss 0.215530
Epoch: 9/20, Batch: 90/469,
Loss: 0.209208 Epoch: 9/20, Batch: 100/469, Loss: 0.223697 Epoch: 9/20, Batch: 110/469, Loss: 0.214186 Epoch: 9/20, Batch: 120/469, Loss: 0.212217 Epoch: 9/20, Batch: 130/469, Loss: 0.218618 Epoch: 9/20, Batch: 140/469, Loss: 0.209719 Epoch: 9/20, Batch: 150/469, Loss: 0.225030 Epoch: 9/20, Batch: 160/469, Loss: 0.210015 Epoch: 9/20, Batch: 170/469, Loss: 0.221006 Epoch: 9/20, Batch: 180/469, Loss: 0.213555 Epoch: 9/20, Batch: 190/469, Loss: 0.221309 Epoch: 9/20, Batch: 200/469, Loss: 0.221660 Epoch: 9/20, Batch: 210/469, Loss: 0.211858 Epoch: 9/20, Batch: 220/469, Loss: 0.213579 Epoch: 9/20, Batch: 230/469, Loss: 0.212624 Epoch: 9/20, Batch: 240/469, Loss: 0.218029 Epoch: 9/20, Batch: 250/469, Loss: 0.207524 Epoch: 9/20, Batch: 260/469, Loss: 0.221810 Epoch: 9/20, Batch: 270/469, Loss: 0.214967 Epoch: 9/20, Batch: 280/469, Loss: 0.213709 Epoch: 9/20, Batch: 290/469, Loss: 0.211457 Epoch: 9/20, Batch: 300/469, Loss: 0.212359 Epoch: 9/20, Batch: 310/469, Loss: 0.215434 Epoch: 9/20, Batch: 320/469, Loss: 0.211776 Epoch: 9/20, Batch: 330/469, Loss: 0.217738 Epoch: 9/20, Batch: 340/469, Loss: 0.209256 Epoch: 9/20, Batch: 350/469, Loss: 0.223196 Epoch: 9/20, Batch: 360/469, Loss: 0.214401 Epoch: 9/20, Batch: 370/469, Loss: 0.214568 Epoch: 9/20, Batch: 380/469, Loss: 0.215216 Epoch: 9/20, Batch: 390/469, Loss: 0.213957 Epoch: 9/20, Batch: 400/469, Loss: 0.209357 Epoch: 9/20, Batch: 410/469, Loss: 0.219765 Epoch: 9/20, Batch: 420/469, Loss: 0.217457 Epoch: 9/20, Batch: 430/469, Loss: 0.215437 Epoch: 9/20, Batch: 440/469, Loss: 0.204476 Epoch: 9/20, Batch: 450/469, Loss: 0.215620 Epoch: 9/20, Batch: 460/469, Loss: 0.215351 Epoch 9 average training loss: 0.215185 Epoch 9 test loss: 0.212113 Current learning rate: 0.001000 New best model with test loss: 0.212113 Epoch: 10/20, Batch: 0/469, Loss: 0.217446 Epoch: 10/20, Batch: 10/469, Loss: 0.204490 Epoch: 10/20, Batch: 20/469, Loss: 0.208787 Epoch: 10/20, Batch: 30/469, Loss: 0.214365 Epoch: 10/20, Batch: 40/469, Loss: 
0.214507 Epoch: 10/20, Batch: 50/469, Loss: 0.213939 Epoch: 10/20, Batch: 60/469, Loss: 0.213126 Epoch: 10/20, Batch: 70/469, Loss: 0.208756 Epoch: 10/20, Batch: 80/469, Loss: 0.215561 Epoch: 10/20, Batch: 90/469, Loss: 0.210820 Epoch: 10/20, Batch: 100/469, Loss: 0.213531 Epoch: 10/20, Batch: 110/469, Loss: 0.211186 Epoch: 10/20, Batch: 120/469, Loss: 0.211150 Epoch: 10/20, Batch: 130/469, Loss: 0.205882 Epoch: 10/20, Batch: 140/469, Loss: 0.206015 Epoch: 10/20, Batch: 150/469, Loss: 0.205237 Epoch: 10/20, Batch: 160/469, Loss: 0.208499 Epoch: 10/20, Batch: 170/469, Loss: 0.213678 Epoch: 10/20, Batch: 180/469, Loss: 0.205830 Epoch: 10/20, Batch: 190/469, Loss: 0.220421 Epoch: 10/20, Batch: 200/469, Loss: 0.209436 Epoch: 10/20, Batch: 210/469, Loss: 0.208071 Epoch: 10/20, Batch: 220/469, Loss: 0.213684 Epoch: 10/20, Batch: 230/469, Loss: 0.211148 Epoch: 10/20, Batch: 240/469, Loss: 0.209753 Epoch: 10/20, Batch: 250/469, Loss: 0.207899 Epoch: 10/20, Batch: 260/469, Loss: 0.214689 Epoch: 10/20, Batch: 270/469, Loss: 0.203741 Epoch: 10/20, Batch: 280/469, Loss: 0.216980 Epoch: 10/20, Batch: 290/469, Loss: 0.215852 Epoch: 10/20, Batch: 300/469, Loss: 0.207830 Epoch: 10/20, Batch: 310/469, Loss: 0.216609 Epoch: 10/20, Batch: 320/469, Loss: 0.213840 Epoch: 10/20, Batch: 330/469, Loss: 0.209219 Epoch: 10/20, Batch: 340/469, Loss: 0.217362 Epoch: 10/20, Batch: 350/469, Loss: 0.216105 Epoch: 10/20, Batch: 360/469, Loss: 0.215934 Epoch: 10/20, Batch: 370/469, Loss: 0.211426 Epoch: 10/20, Batch: 380/469, Loss: 0.213747 Epoch: 10/20, Batch: 390/469, Loss: 0.213301 Epoch: 10/20, Batch: 400/469, Loss: 0.214042 Epoch: 10/20, Batch: 410/469, Loss: 0.202593 Epoch: 10/20, Batch: 420/469, Loss: 0.216908 Epoch: 10/20, Batch: 430/469, Loss: 0.214638 Epoch: 10/20, Batch: 440/469, Loss: 0.204698 Epoch: 10/20, Batch: 450/469, Loss: 0.210592 Epoch: 10/20, Batch: 460/469, Loss: 0.213882 Epoch 10 average training loss: 0.212729 Epoch 10 test loss: 0.211134 Current learning rate: 0.001000 New 
best model with test loss: 0.211134 Epoch: 11/20, Batch: 0/469, Loss: 0.214763 Epoch: 11/20, Batch: 10/469, Loss: 0.219020 Epoch: 11/20, Batch: 20/469, Loss: 0.213717 Epoch: 11/20, Batch: 30/469, Loss: 0.209371 Epoch: 11/20, Batch: 40/469, Loss: 0.207829 Epoch: 11/20, Batch: 50/469, Loss: 0.207739 Epoch: 11/20, Batch: 60/469, Loss: 0.216472 Epoch: 11/20, Batch: 70/469, Loss: 0.211675 Epoch: 11/20, Batch: 80/469, Loss: 0.210735 Epoch: 11/20, Batch: 90/469, Loss: 0.214069 Epoch: 11/20, Batch: 100/469, Loss: 0.212001 Epoch: 11/20, Batch: 110/469, Loss: 0.212480 Epoch: 11/20, Batch: 120/469, Loss: 0.216690 Epoch: 11/20, Batch: 130/469, Loss: 0.206841 Epoch: 11/20, Batch: 140/469, Loss: 0.216091 Epoch: 11/20, Batch: 150/469, Loss: 0.209724 Epoch: 11/20, Batch: 160/469, Loss: 0.212609 Epoch: 11/20, Batch: 170/469, Loss: 0.211049 Epoch: 11/20, Batch: 180/469, Loss: 0.219174 Epoch: 11/20, Batch: 190/469, Loss: 0.211215 Epoch: 11/20, Batch: 200/469, Loss: 0.220481 Epoch: 11/20, Batch: 210/469, Loss: 0.215575 Epoch: 11/20, Batch: 220/469, Loss: 0.212099 Epoch: 11/20, Batch: 230/469, Loss: 0.215937 Epoch: 11/20, Batch: 240/469, Loss: 0.215622 Epoch: 11/20, Batch: 250/469, Loss: 0.209118 Epoch: 11/20, Batch: 260/469, Loss: 0.219197 Epoch: 11/20, Batch: 270/469, Loss: 0.204003 Epoch: 11/20, Batch: 280/469, Loss: 0.214908 Epoch: 11/20, Batch: 290/469, Loss: 0.208491 Epoch: 11/20, Batch: 300/469, Loss: 0.207092 Epoch: 11/20, Batch: 310/469, Loss: 0.215001 Epoch: 11/20, Batch: 320/469, Loss: 0.215088 Epoch: 11/20, Batch: 330/469, Loss: 0.214138 Epoch: 11/20, Batch: 340/469, Loss: 0.215318 Epoch: 11/20, Batch: 350/469, Loss: 0.207700 Epoch: 11/20, Batch: 360/469, Loss: 0.215858 Epoch: 11/20, Batch: 370/469, Loss: 0.213861 Epoch: 11/20, Batch: 380/469, Loss: 0.207773 Epoch: 11/20, Batch: 390/469, Loss: 0.214390 Epoch: 11/20, Batch: 400/469, Loss: 0.213081 Epoch: 11/20, Batch: 410/469, Loss: 0.203974 Epoch: 11/20, Batch: 420/469, Loss: 0.215607 Epoch: 11/20, Batch: 430/469, Loss: 
0.214710 Epoch: 11/20, Batch: 440/469, Loss: 0.200902 Epoch: 11/20, Batch: 450/469, Loss: 0.206424 Epoch: 11/20, Batch: 460/469, Loss: 0.210066 Epoch 11 average training loss: 0.211831 Epoch 11 test loss: 0.211603 Current learning rate: 0.001000 No improvement for 1 epochs Epoch: 12/20, Batch: 0/469, Loss: 0.211043 Epoch: 12/20, Batch: 10/469, Loss: 0.206815 Epoch: 12/20, Batch: 20/469, Loss: 0.215231 Epoch: 12/20, Batch: 30/469, Loss: 0.205282 Epoch: 12/20, Batch: 40/469, Loss: 0.207909 Epoch: 12/20, Batch: 50/469, Loss: 0.212512 Epoch: 12/20, Batch: 60/469, Loss: 0.213266 Epoch: 12/20, Batch: 70/469, Loss: 0.206314 Epoch: 12/20, Batch: 80/469, Loss: 0.211810 Epoch: 12/20, Batch: 90/469, Loss: 0.210911 Epoch: 12/20, Batch: 100/469, Loss: 0.210767 Epoch: 12/20, Batch: 110/469, Loss: 0.211135 Epoch: 12/20, Batch: 120/469, Loss: 0.217679 Epoch: 12/20, Batch: 130/469, Loss: 0.211737 Epoch: 12/20, Batch: 140/469, Loss: 0.211754 Epoch: 12/20, Batch: 150/469, Loss: 0.211974 Epoch: 12/20, Batch: 160/469, Loss: 0.210566 Epoch: 12/20, Batch: 170/469, Loss: 0.217121 Epoch: 12/20, Batch: 180/469, Loss: 0.221854 Epoch: 12/20, Batch: 190/469, Loss: 0.209463 Epoch: 12/20, Batch: 200/469, Loss: 0.213184 Epoch: 12/20, Batch: 210/469, Loss: 0.213088 Epoch: 12/20, Batch: 220/469, Loss: 0.216206 Epoch: 12/20, Batch: 230/469, Loss: 0.211522 Epoch: 12/20, Batch: 240/469, Loss: 0.218441 Epoch: 12/20, Batch: 250/469, Loss: 0.213431 Epoch: 12/20, Batch: 260/469, Loss: 0.211362 Epoch: 12/20, Batch: 270/469, Loss: 0.216138 Epoch: 12/20, Batch: 280/469, Loss: 0.210831 Epoch: 12/20, Batch: 290/469, Loss: 0.207998 Epoch: 12/20, Batch: 300/469, Loss: 0.216001 Epoch: 12/20, Batch: 310/469, Loss: 0.210044 Epoch: 12/20, Batch: 320/469, Loss: 0.212161 Epoch: 12/20, Batch: 330/469, Loss: 0.212306 Epoch: 12/20, Batch: 340/469, Loss: 0.203685 Epoch: 12/20, Batch: 350/469, Loss: 0.210994 Epoch: 12/20, Batch: 360/469, Loss: 0.218511 Epoch: 12/20, Batch: 370/469, Loss: 0.207289 Epoch: 12/20, Batch: 
380/469, Loss: 0.213954 Epoch: 12/20, Batch: 390/469, Loss: 0.213009 Epoch: 12/20, Batch: 400/469, Loss: 0.210243 Epoch: 12/20, Batch: 410/469, Loss: 0.217228 Epoch: 12/20, Batch: 420/469, Loss: 0.204950 Epoch: 12/20, Batch: 430/469, Loss: 0.206365 Epoch: 12/20, Batch: 440/469, Loss: 0.207523 Epoch: 12/20, Batch: 450/469, Loss: 0.212558 Epoch: 12/20, Batch: 460/469, Loss: 0.214329 Epoch 12 average training loss: 0.211093 Epoch 12 test loss: 0.209030 Current learning rate: 0.001000 New best model with test loss: 0.209030 Epoch: 13/20, Batch: 0/469, Loss: 0.210873 Epoch: 13/20, Batch: 10/469, Loss: 0.206721 Epoch: 13/20, Batch: 20/469, Loss: 0.219018 Epoch: 13/20, Batch: 30/469, Loss: 0.212593 Epoch: 13/20, Batch: 40/469, Loss: 0.212240 Epoch: 13/20, Batch: 50/469, Loss: 0.217745 Epoch: 13/20, Batch: 60/469, Loss: 0.212483 Epoch: 13/20, Batch: 70/469, Loss: 0.210983 Epoch: 13/20, Batch: 80/469, Loss: 0.210471 Epoch: 13/20, Batch: 90/469, Loss: 0.206260 Epoch: 13/20, Batch: 100/469, Loss: 0.202827 Epoch: 13/20, Batch: 110/469, Loss: 0.212014 Epoch: 13/20, Batch: 120/469, Loss: 0.206353 Epoch: 13/20, Batch: 130/469, Loss: 0.212223 Epoch: 13/20, Batch: 140/469, Loss: 0.209265 Epoch: 13/20, Batch: 150/469, Loss: 0.209393 Epoch: 13/20, Batch: 160/469, Loss: 0.207558 Epoch: 13/20, Batch: 170/469, Loss: 0.208564 Epoch: 13/20, Batch: 180/469, Loss: 0.206930 Epoch: 13/20, Batch: 190/469, Loss: 0.214779 Epoch: 13/20, Batch: 200/469, Loss: 0.211891 Epoch: 13/20, Batch: 210/469, Loss: 0.202581 Epoch: 13/20, Batch: 220/469, Loss: 0.211638 Epoch: 13/20, Batch: 230/469, Loss: 0.215342 Epoch: 13/20, Batch: 240/469, Loss: 0.208895 Epoch: 13/20, Batch: 250/469, Loss: 0.212216 Epoch: 13/20, Batch: 260/469, Loss: 0.207107 Epoch: 13/20, Batch: 270/469, Loss: 0.203087 Epoch: 13/20, Batch: 280/469, Loss: 0.212842 Epoch: 13/20, Batch: 290/469, Loss: 0.207821 Epoch: 13/20, Batch: 300/469, Loss: 0.210786 Epoch: 13/20, Batch: 310/469, Loss: 0.202894 Epoch: 13/20, Batch: 320/469, Loss: 0.204934 
Epoch: 13/20, Batch: 330/469, Loss: 0.209444 Epoch: 13/20, Batch: 340/469, Loss: 0.211696 Epoch: 13/20, Batch: 350/469, Loss: 0.211397 Epoch: 13/20, Batch: 360/469, Loss: 0.212550 Epoch: 13/20, Batch: 370/469, Loss: 0.206205 Epoch: 13/20, Batch: 380/469, Loss: 0.206187 Epoch: 13/20, Batch: 390/469, Loss: 0.214461 Epoch: 13/20, Batch: 400/469, Loss: 0.208282 Epoch: 13/20, Batch: 410/469, Loss: 0.212585 Epoch: 13/20, Batch: 420/469, Loss: 0.205367 Epoch: 13/20, Batch: 430/469, Loss: 0.200762 Epoch: 13/20, Batch: 440/469, Loss: 0.207483 Epoch: 13/20, Batch: 450/469, Loss: 0.209540 Epoch: 13/20, Batch: 460/469, Loss: 0.216215 Epoch 13 average training loss: 0.210340 Epoch 13 test loss: 0.208866 Current learning rate: 0.001000 New best model with test loss: 0.208866 Epoch: 14/20, Batch: 0/469, Loss: 0.212771 Epoch: 14/20, Batch: 10/469, Loss: 0.213373 Epoch: 14/20, Batch: 20/469, Loss: 0.199179 Epoch: 14/20, Batch: 30/469, Loss: 0.208138 Epoch: 14/20, Batch: 40/469, Loss: 0.205600 Epoch: 14/20, Batch: 50/469, Loss: 0.200619 Epoch: 14/20, Batch: 60/469, Loss: 0.209880 Epoch: 14/20, Batch: 70/469, Loss: 0.209947 Epoch: 14/20, Batch: 80/469, Loss: 0.221133 Epoch: 14/20, Batch: 90/469, Loss: 0.212685 Epoch: 14/20, Batch: 100/469, Loss: 0.213046 Epoch: 14/20, Batch: 110/469, Loss: 0.209327 Epoch: 14/20, Batch: 120/469, Loss: 0.208958 Epoch: 14/20, Batch: 130/469, Loss: 0.209591 Epoch: 14/20, Batch: 140/469, Loss: 0.205718 Epoch: 14/20, Batch: 150/469, Loss: 0.207877 Epoch: 14/20, Batch: 160/469, Loss: 0.210706 Epoch: 14/20, Batch: 170/469, Loss: 0.213620 Epoch: 14/20, Batch: 180/469, Loss: 0.206293 Epoch: 14/20, Batch: 190/469, Loss: 0.205369 Epoch: 14/20, Batch: 200/469, Loss: 0.201762 Epoch: 14/20, Batch: 210/469, Loss: 0.211083 Epoch: 14/20, Batch: 220/469, Loss: 0.205765 Epoch: 14/20, Batch: 230/469, Loss: 0.208427 Epoch: 14/20, Batch: 240/469, Loss: 0.209432 Epoch: 14/20, Batch: 250/469, Loss: 0.206332 Epoch: 14/20, Batch: 260/469, Loss: 0.206169 Epoch: 14/20, Batch: 
270/469, Loss: 0.210023 Epoch: 14/20, Batch: 280/469, Loss: 0.211498 Epoch: 14/20, Batch: 290/469, Loss: 0.206233 Epoch: 14/20, Batch: 300/469, Loss: 0.208469 Epoch: 14/20, Batch: 310/469, Loss: 0.197393 Epoch: 14/20, Batch: 320/469, Loss: 0.213416 Epoch: 14/20, Batch: 330/469, Loss: 0.202493 Epoch: 14/20, Batch: 340/469, Loss: 0.210662 Epoch: 14/20, Batch: 350/469, Loss: 0.214152 Epoch: 14/20, Batch: 360/469, Loss: 0.209098 Epoch: 14/20, Batch: 370/469, Loss: 0.208248 Epoch: 14/20, Batch: 380/469, Loss: 0.203246 Epoch: 14/20, Batch: 390/469, Loss: 0.209957 Epoch: 14/20, Batch: 400/469, Loss: 0.207177 Epoch: 14/20, Batch: 410/469, Loss: 0.204137 Epoch: 14/20, Batch: 420/469, Loss: 0.209079 Epoch: 14/20, Batch: 430/469, Loss: 0.210575 Epoch: 14/20, Batch: 440/469, Loss: 0.209693 Epoch: 14/20, Batch: 450/469, Loss: 0.216841 Epoch: 14/20, Batch: 460/469, Loss: 0.211794 Epoch 14 average training loss: 0.209595 Epoch 14 test loss: 0.209902 Current learning rate: 0.001000 No improvement for 1 epochs Epoch: 15/20, Batch: 0/469, Loss: 0.207367 Epoch: 15/20, Batch: 10/469, Loss: 0.204525 Epoch: 15/20, Batch: 20/469, Loss: 0.208534 Epoch: 15/20, Batch: 30/469, Loss: 0.213829 Epoch: 15/20, Batch: 40/469, Loss: 0.206152 Epoch: 15/20, Batch: 50/469, Loss: 0.206549 Epoch: 15/20, Batch: 60/469, Loss: 0.211006 Epoch: 15/20, Batch: 70/469, Loss: 0.201411 Epoch: 15/20, Batch: 80/469, Loss: 0.208558 Epoch: 15/20, Batch: 90/469, Loss: 0.208011 Epoch: 15/20, Batch: 100/469, Loss: 0.219043 Epoch: 15/20, Batch: 110/469, Loss: 0.204353 Epoch: 15/20, Batch: 120/469, Loss: 0.212703 Epoch: 15/20, Batch: 130/469, Loss: 0.211189 Epoch: 15/20, Batch: 140/469, Loss: 0.209152 Epoch: 15/20, Batch: 150/469, Loss: 0.206217 Epoch: 15/20, Batch: 160/469, Loss: 0.207095 Epoch: 15/20, Batch: 170/469, Loss: 0.201941 Epoch: 15/20, Batch: 180/469, Loss: 0.218086 Epoch: 15/20, Batch: 190/469, Loss: 0.204842 Epoch: 15/20, Batch: 200/469, Loss: 0.208185 Epoch: 15/20, Batch: 210/469, Loss: 0.212400 Epoch: 
15/20, Batch: 220/469, Loss: 0.216292 Epoch: 15/20, Batch: 230/469, Loss: 0.212196 Epoch: 15/20, Batch: 240/469, Loss: 0.210542 Epoch: 15/20, Batch: 250/469, Loss: 0.202203 Epoch: 15/20, Batch: 260/469, Loss: 0.209709 Epoch: 15/20, Batch: 270/469, Loss: 0.203776 Epoch: 15/20, Batch: 280/469, Loss: 0.209241 Epoch: 15/20, Batch: 290/469, Loss: 0.203724 Epoch: 15/20, Batch: 300/469, Loss: 0.210055 Epoch: 15/20, Batch: 310/469, Loss: 0.209016 Epoch: 15/20, Batch: 320/469, Loss: 0.218477 Epoch: 15/20, Batch: 330/469, Loss: 0.207926 Epoch: 15/20, Batch: 340/469, Loss: 0.213963 Epoch: 15/20, Batch: 350/469, Loss: 0.206167 Epoch: 15/20, Batch: 360/469, Loss: 0.212611 Epoch: 15/20, Batch: 370/469, Loss: 0.205985 Epoch: 15/20, Batch: 380/469, Loss: 0.206174 Epoch: 15/20, Batch: 390/469, Loss: 0.202105 Epoch: 15/20, Batch: 400/469, Loss: 0.208659 Epoch: 15/20, Batch: 410/469, Loss: 0.207371 Epoch: 15/20, Batch: 420/469, Loss: 0.212601 Epoch: 15/20, Batch: 430/469, Loss: 0.204636 Epoch: 15/20, Batch: 440/469, Loss: 0.216152 Epoch: 15/20, Batch: 450/469, Loss: 0.207220 Epoch: 15/20, Batch: 460/469, Loss: 0.208323 Epoch 15 average training loss: 0.209439 Epoch 15 test loss: 0.209275 Current learning rate: 0.001000 No improvement for 2 epochs Epoch: 16/20, Batch: 0/469, Loss: 0.215293 Epoch: 16/20, Batch: 10/469, Loss: 0.213600 Epoch: 16/20, Batch: 20/469, Loss: 0.210042 Epoch: 16/20, Batch: 30/469, Loss: 0.210281 Epoch: 16/20, Batch: 40/469, Loss: 0.206723 Epoch: 16/20, Batch: 50/469, Loss: 0.210724 Epoch: 16/20, Batch: 60/469, Loss: 0.215392 Epoch: 16/20, Batch: 70/469, Loss: 0.212187 Epoch: 16/20, Batch: 80/469, Loss: 0.213717 Epoch: 16/20, Batch: 90/469, Loss: 0.208688 Epoch: 16/20, Batch: 100/469, Loss: 0.208639 Epoch: 16/20, Batch: 110/469, Loss: 0.212730 Epoch: 16/20, Batch: 120/469, Loss: 0.217212 Epoch: 16/20, Batch: 130/469, Loss: 0.205065 Epoch: 16/20, Batch: 140/469, Loss: 0.210744 Epoch: 16/20, Batch: 150/469, Loss: 0.205616 Epoch: 16/20, Batch: 160/469, Loss: 
0.206760 Epoch: 16/20, Batch: 170/469, Loss: 0.205589 Epoch: 16/20, Batch: 180/469, Loss: 0.208279 Epoch: 16/20, Batch: 190/469, Loss: 0.215898 Epoch: 16/20, Batch: 200/469, Loss: 0.209263 Epoch: 16/20, Batch: 210/469, Loss: 0.209366 Epoch: 16/20, Batch: 220/469, Loss: 0.203065 Epoch: 16/20, Batch: 230/469, Loss: 0.206546 Epoch: 16/20, Batch: 240/469, Loss: 0.207617 Epoch: 16/20, Batch: 250/469, Loss: 0.209491 Epoch: 16/20, Batch: 260/469, Loss: 0.213042 Epoch: 16/20, Batch: 270/469, Loss: 0.213224 Epoch: 16/20, Batch: 280/469, Loss: 0.209430 Epoch: 16/20, Batch: 290/469, Loss: 0.212271 Epoch: 16/20, Batch: 300/469, Loss: 0.212083 Epoch: 16/20, Batch: 310/469, Loss: 0.210680 Epoch: 16/20, Batch: 320/469, Loss: 0.208942 Epoch: 16/20, Batch: 330/469, Loss: 0.209899 Epoch: 16/20, Batch: 340/469, Loss: 0.215008 Epoch: 16/20, Batch: 350/469, Loss: 0.207944 Epoch: 16/20, Batch: 360/469, Loss: 0.214226 Epoch: 16/20, Batch: 370/469, Loss: 0.205244 Epoch: 16/20, Batch: 380/469, Loss: 0.201590 Epoch: 16/20, Batch: 390/469, Loss: 0.202800 Epoch: 16/20, Batch: 400/469, Loss: 0.205809 Epoch: 16/20, Batch: 410/469, Loss: 0.202333 Epoch: 16/20, Batch: 420/469, Loss: 0.214613 Epoch: 16/20, Batch: 430/469, Loss: 0.206397 Epoch: 16/20, Batch: 440/469, Loss: 0.208402 Epoch: 16/20, Batch: 450/469, Loss: 0.202472 Epoch: 16/20, Batch: 460/469, Loss: 0.203844 Epoch 16 average training loss: 0.209325 Epoch 16 test loss: 0.206827 Current learning rate: 0.001000 New best model with test loss: 0.206827 Epoch: 17/20, Batch: 0/469, Loss: 0.202556 Epoch: 17/20, Batch: 10/469, Loss: 0.207812 Epoch: 17/20, Batch: 20/469, Loss: 0.204858 Epoch: 17/20, Batch: 30/469, Loss: 0.211211 Epoch: 17/20, Batch: 40/469, Loss: 0.200099 Epoch: 17/20, Batch: 50/469, Loss: 0.206627 Epoch: 17/20, Batch: 60/469, Loss: 0.208037 Epoch: 17/20, Batch: 70/469, Loss: 0.208377 Epoch: 17/20, Batch: 80/469, Loss: 0.210629 Epoch: 17/20, Batch: 90/469, Loss: 0.214234 Epoch: 17/20, Batch: 100/469, Loss: 0.212004 Epoch: 17/20, 
Batch: 110/469, Loss: 0.212999 Epoch: 17/20, Batch: 120/469, Loss: 0.209832 Epoch: 17/20, Batch: 130/469, Loss: 0.210132 Epoch: 17/20, Batch: 140/469, Loss: 0.209750 Epoch: 17/20, Batch: 150/469, Loss: 0.216717 Epoch: 17/20, Batch: 160/469, Loss: 0.201560 Epoch: 17/20, Batch: 170/469, Loss: 0.209595 Epoch: 17/20, Batch: 180/469, Loss: 0.209703 Epoch: 17/20, Batch: 190/469, Loss: 0.218761 Epoch: 17/20, Batch: 200/469, Loss: 0.206203 Epoch: 17/20, Batch: 210/469, Loss: 0.212779 Epoch: 17/20, Batch: 220/469, Loss: 0.205639 Epoch: 17/20, Batch: 230/469, Loss: 0.210037 Epoch: 17/20, Batch: 240/469, Loss: 0.209341 Epoch: 17/20, Batch: 250/469, Loss: 0.213321 Epoch: 17/20, Batch: 260/469, Loss: 0.204016 Epoch: 17/20, Batch: 270/469, Loss: 0.206974 Epoch: 17/20, Batch: 280/469, Loss: 0.211707 Epoch: 17/20, Batch: 290/469, Loss: 0.218377 Epoch: 17/20, Batch: 300/469, Loss: 0.209888 Epoch: 17/20, Batch: 310/469, Loss: 0.213094 Epoch: 17/20, Batch: 320/469, Loss: 0.206226 Epoch: 17/20, Batch: 330/469, Loss: 0.207443 Epoch: 17/20, Batch: 340/469, Loss: 0.206221 Epoch: 17/20, Batch: 350/469, Loss: 0.209708 Epoch: 17/20, Batch: 360/469, Loss: 0.208175 Epoch: 17/20, Batch: 370/469, Loss: 0.215754 Epoch: 17/20, Batch: 380/469, Loss: 0.206180 Epoch: 17/20, Batch: 390/469, Loss: 0.213457 Epoch: 17/20, Batch: 400/469, Loss: 0.207036 Epoch: 17/20, Batch: 410/469, Loss: 0.207870 Epoch: 17/20, Batch: 420/469, Loss: 0.207152 Epoch: 17/20, Batch: 430/469, Loss: 0.207636 Epoch: 17/20, Batch: 440/469, Loss: 0.206491 Epoch: 17/20, Batch: 450/469, Loss: 0.208486 Epoch: 17/20, Batch: 460/469, Loss: 0.207429 Epoch 17 average training loss: 0.209234 Epoch 17 test loss: 0.207524 Current learning rate: 0.001000 No improvement for 1 epochs Epoch: 18/20, Batch: 0/469, Loss: 0.202099 Epoch: 18/20, Batch: 10/469, Loss: 0.214507 Epoch: 18/20, Batch: 20/469, Loss: 0.207212 Epoch: 18/20, Batch: 30/469, Loss: 0.216066 Epoch: 18/20, Batch: 40/469, Loss: 0.213026 Epoch: 18/20, Batch: 50/469, Loss: 0.220747 
Epoch: 18/20, Batch: 60/469, Loss: 0.210804 Epoch: 18/20, Batch: 70/469, Loss: 0.214038 Epoch: 18/20, Batch: 80/469, Loss: 0.207343 Epoch: 18/20, Batch: 90/469, Loss: 0.214469 Epoch: 18/20, Batch: 100/469, Loss: 0.208596 Epoch: 18/20, Batch: 110/469, Loss: 0.212159 Epoch: 18/20, Batch: 120/469, Loss: 0.204651 Epoch: 18/20, Batch: 130/469, Loss: 0.211621 Epoch: 18/20, Batch: 140/469, Loss: 0.210633 Epoch: 18/20, Batch: 150/469, Loss: 0.210425 Epoch: 18/20, Batch: 160/469, Loss: 0.203219 Epoch: 18/20, Batch: 170/469, Loss: 0.211141 Epoch: 18/20, Batch: 180/469, Loss: 0.209775 Epoch: 18/20, Batch: 190/469, Loss: 0.211429 Epoch: 18/20, Batch: 200/469, Loss: 0.205494 Epoch: 18/20, Batch: 210/469, Loss: 0.218133 Epoch: 18/20, Batch: 220/469, Loss: 0.204454 Epoch: 18/20, Batch: 230/469, Loss: 0.210245 Epoch: 18/20, Batch: 240/469, Loss: 0.205546 Epoch: 18/20, Batch: 250/469, Loss: 0.207433 Epoch: 18/20, Batch: 260/469, Loss: 0.208362 Epoch: 18/20, Batch: 270/469, Loss: 0.209357 Epoch: 18/20, Batch: 280/469, Loss: 0.211110 Epoch: 18/20, Batch: 290/469, Loss: 0.209117 Epoch: 18/20, Batch: 300/469, Loss: 0.204339 Epoch: 18/20, Batch: 310/469, Loss: 0.208376 Epoch: 18/20, Batch: 320/469, Loss: 0.213858 Epoch: 18/20, Batch: 330/469, Loss: 0.213692 Epoch: 18/20, Batch: 340/469, Loss: 0.207779 Epoch: 18/20, Batch: 350/469, Loss: 0.205959 Epoch: 18/20, Batch: 360/469, Loss: 0.203879 Epoch: 18/20, Batch: 370/469, Loss: 0.204162 Epoch: 18/20, Batch: 380/469, Loss: 0.205290 Epoch: 18/20, Batch: 390/469, Loss: 0.212493 Epoch: 18/20, Batch: 400/469, Loss: 0.200532 Epoch: 18/20, Batch: 410/469, Loss: 0.211167 Epoch: 18/20, Batch: 420/469, Loss: 0.208053 Epoch: 18/20, Batch: 430/469, Loss: 0.214512 Epoch: 18/20, Batch: 440/469, Loss: 0.201038 Epoch: 18/20, Batch: 450/469, Loss: 0.213429 Epoch: 18/20, Batch: 460/469, Loss: 0.204935 Epoch 18 average training loss: 0.209099 Epoch 18 test loss: 0.206903 Current learning rate: 0.001000 No improvement for 2 epochs Epoch: 19/20, Batch: 0/469, 
Loss: 0.207390 Epoch: 19/20, Batch: 10/469, Loss: 0.206777 Epoch: 19/20, Batch: 20/469, Loss: 0.212507 Epoch: 19/20, Batch: 30/469, Loss: 0.209222 Epoch: 19/20, Batch: 40/469, Loss: 0.203609 Epoch: 19/20, Batch: 50/469, Loss: 0.207364 Epoch: 19/20, Batch: 60/469, Loss: 0.202807 Epoch: 19/20, Batch: 70/469, Loss: 0.206227 Epoch: 19/20, Batch: 80/469, Loss: 0.207104 Epoch: 19/20, Batch: 90/469, Loss: 0.206067 Epoch: 19/20, Batch: 100/469, Loss: 0.207564 Epoch: 19/20, Batch: 110/469, Loss: 0.206372 Epoch: 19/20, Batch: 120/469, Loss: 0.212251 Epoch: 19/20, Batch: 130/469, Loss: 0.217493 Epoch: 19/20, Batch: 140/469, Loss: 0.208805 Epoch: 19/20, Batch: 150/469, Loss: 0.213019 Epoch: 19/20, Batch: 160/469, Loss: 0.210528 Epoch: 19/20, Batch: 170/469, Loss: 0.209955 Epoch: 19/20, Batch: 180/469, Loss: 0.211152 Epoch: 19/20, Batch: 190/469, Loss: 0.205525 Epoch: 19/20, Batch: 200/469, Loss: 0.211048 Epoch: 19/20, Batch: 210/469, Loss: 0.212028 Epoch: 19/20, Batch: 220/469, Loss: 0.208617 Epoch: 19/20, Batch: 230/469, Loss: 0.208809 Epoch: 19/20, Batch: 240/469, Loss: 0.203364 Epoch: 19/20, Batch: 250/469, Loss: 0.210144 Epoch: 19/20, Batch: 260/469, Loss: 0.210097 Epoch: 19/20, Batch: 270/469, Loss: 0.214254 Epoch: 19/20, Batch: 280/469, Loss: 0.208904 Epoch: 19/20, Batch: 290/469, Loss: 0.218724 Epoch: 19/20, Batch: 300/469, Loss: 0.199022 Epoch: 19/20, Batch: 310/469, Loss: 0.210516 Epoch: 19/20, Batch: 320/469, Loss: 0.209234 Epoch: 19/20, Batch: 330/469, Loss: 0.210159 Epoch: 19/20, Batch: 340/469, Loss: 0.198755 Epoch: 19/20, Batch: 350/469, Loss: 0.211019 Epoch: 19/20, Batch: 360/469, Loss: 0.199160 Epoch: 19/20, Batch: 370/469, Loss: 0.210670 Epoch: 19/20, Batch: 380/469, Loss: 0.212555 Epoch: 19/20, Batch: 390/469, Loss: 0.212382 Epoch: 19/20, Batch: 400/469, Loss: 0.207561 Epoch: 19/20, Batch: 410/469, Loss: 0.209851 Epoch: 19/20, Batch: 420/469, Loss: 0.210867 Epoch: 19/20, Batch: 430/469, Loss: 0.209916 Epoch: 19/20, Batch: 440/469, Loss: 0.206035 Epoch: 19/20, 
Batch: 450/469, Loss: 0.206083 Epoch: 19/20, Batch: 460/469, Loss: 0.207148 Epoch 19 average training loss: 0.208926 Epoch 19 test loss: 0.210435 Current learning rate: 0.000500 No improvement for 3 epochs Epoch: 20/20, Batch: 0/469, Loss: 0.211074 Epoch: 20/20, Batch: 10/469, Loss: 0.204439 Epoch: 20/20, Batch: 20/469, Loss: 0.208146 Epoch: 20/20, Batch: 30/469, Loss: 0.201846 Epoch: 20/20, Batch: 40/469, Loss: 0.207145 Epoch: 20/20, Batch: 50/469, Loss: 0.207307 Epoch: 20/20, Batch: 60/469, Loss: 0.202797 Epoch: 20/20, Batch: 70/469, Loss: 0.210327 Epoch: 20/20, Batch: 80/469, Loss: 0.205838 Epoch: 20/20, Batch: 90/469, Loss: 0.203876 Epoch: 20/20, Batch: 100/469, Loss: 0.205854 Epoch: 20/20, Batch: 110/469, Loss: 0.204464 Epoch: 20/20, Batch: 120/469, Loss: 0.206882 Epoch: 20/20, Batch: 130/469, Loss: 0.206017 Epoch: 20/20, Batch: 140/469, Loss: 0.202314 Epoch: 20/20, Batch: 150/469, Loss: 0.209197 Epoch: 20/20, Batch: 160/469, Loss: 0.204208 Epoch: 20/20, Batch: 170/469, Loss: 0.203223 Epoch: 20/20, Batch: 180/469, Loss: 0.203935 Epoch: 20/20, Batch: 190/469, Loss: 0.206490 Epoch: 20/20, Batch: 200/469, Loss: 0.204954 Epoch: 20/20, Batch: 210/469, Loss: 0.207911 Epoch: 20/20, Batch: 220/469, Loss: 0.210991 Epoch: 20/20, Batch: 230/469, Loss: 0.199542 Epoch: 20/20, Batch: 240/469, Loss: 0.202942 Epoch: 20/20, Batch: 250/469, Loss: 0.207271 Epoch: 20/20, Batch: 260/469, Loss: 0.203368 Epoch: 20/20, Batch: 270/469, Loss: 0.198808 Epoch: 20/20, Batch: 280/469, Loss: 0.203112 Epoch: 20/20, Batch: 290/469, Loss: 0.203000 Epoch: 20/20, Batch: 300/469, Loss: 0.202056 Epoch: 20/20, Batch: 310/469, Loss: 0.205928 Epoch: 20/20, Batch: 320/469, Loss: 0.201867 Epoch: 20/20, Batch: 330/469, Loss: 0.204746 Epoch: 20/20, Batch: 340/469, Loss: 0.201211 Epoch: 20/20, Batch: 350/469, Loss: 0.204882 Epoch: 20/20, Batch: 360/469, Loss: 0.205628 Epoch: 20/20, Batch: 370/469, Loss: 0.203249 Epoch: 20/20, Batch: 380/469, Loss: 0.205749 Epoch: 20/20, Batch: 390/469, Loss: 0.204313 
Epoch: 20/20, Batch: 400/469, Loss: 0.201741 Epoch: 20/20, Batch: 410/469, Loss: 0.207491 Epoch: 20/20, Batch: 420/469, Loss: 0.205125 Epoch: 20/20, Batch: 430/469, Loss: 0.196546 Epoch: 20/20, Batch: 440/469, Loss: 0.207992 Epoch: 20/20, Batch: 450/469, Loss: 0.203138 Epoch: 20/20, Batch: 460/469, Loss: 0.207264 Epoch 20 average training loss: 0.205626 Epoch 20 test loss: 0.204501 Current learning rate: 0.000500 New best model with test loss: 0.204501 Loaded best model with test loss: 0.204501 Final Test Loss: 0.2045
samples shape: (100, 28, 28, 3)
Question 3: Causal Transformer - iGPT¶
Now we will move on to the transformer, currently the most popular and widespread autoregressive model.
Part (a) Autoregressive Transformer on Shapes and MNIST¶
In this part, implement a simple Autoregressive Transformer to model binary MNIST and shapes images (same as Q2(a), but with a Transformer).
Some additional notes about your transformer implementation:
- iGPT uses learned positional encodings. We recommend using them here as well. However, you may also use sinusoidal positional encodings if you wish (see the Attention Is All You Need paper).
- An autoregressive transformer always predicts the next token, given the prior tokens. iGPT prepends a special <bos> (beginning-of-sequence) token to every sequence, i.e. to every image. Make sure to include this in your implementation as well. You can generate unconditional samples by conditioning on just the <bos> token.
- While dropout is a common feature in transformer models, you do not need to add it (but may if you wish!).
- Prebuilt transformers exist in some frameworks (e.g. PyTorch). Don't just use an off-the-shelf implementation, as the point of the exercise is to better understand the transformer architecture. Build the transformer from the ground up, using primitives such as Linear/Dense layers, LayerNorm, GeLU, and Embedding.
- Learning rate warmup and cosine learning rate decay are often used when training transformers to improve training stability and performance. See if this helps your model! Try 1000 steps of warmup followed by a cosine learning rate decay.
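As a concrete illustration of the last point, warmup plus cosine decay can be written as a multiplier on the base learning rate. This is our own sketch, not part of the starter code; the names `warmup_steps` and `total_steps` are assumptions:

```python
import math

def lr_multiplier(step, warmup_steps=1000, total_steps=20000):
    """Linear warmup to 1.0 over warmup_steps, then cosine decay to 0.0."""
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, a function like this can be passed to `torch.optim.lr_scheduler.LambdaLR` so the optimizer's learning rate is scaled by it at every step.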
Paper references
- Attention Is All You Need
- Generative Pretraining from Pixels
- Language Models are Unsupervised Multitask Learners
We recommend the following network design parameters:
- $d_{model}$: 128
- heads: 4
- layers: 2
- GeLU nonlinearities
And the following hyperparameters:
- Batch size: 64 or 32 or 16 (whichever fits in your GPU)
- Learning rate: $10^{-3}$
- 15 epochs or more
- Adam Optimizer (this applies to all Transformer models trained in future parts)
You will provide these deliverables:
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- 100 samples from the final trained model
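For reference, "nats / dim" is just the total negative log-likelihood (natural log) divided by the number of dimensions (here, pixels) being modeled. A minimal sketch with made-up numbers (not from the assignment):

```python
import math

def nats_per_dim(log_likelihoods):
    """Average negative log-likelihood per dimension, in nats.

    log_likelihoods: per-pixel log-probabilities (natural log) that the
    model assigns to the observed pixel values.
    """
    return -sum(log_likelihoods) / len(log_likelihoods)

# e.g. a model that assigns probability 0.8 to every observed binary pixel
probs = [0.8] * (28 * 28)
nll = nats_per_dim([math.log(p) for p in probs])  # = -ln(0.8) ~= 0.223 nats/dim
```

Note that PyTorch's `F.cross_entropy` with the default `reduction='mean'` already reports exactly this quantity when applied per pixel, since it uses the natural logarithm.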
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """
        q: (batch_size, n_heads, seq_len, head_size)
        k: (batch_size, n_heads, seq_len, head_size)
        v: (batch_size, n_heads, seq_len, head_size)
        """
        d_k = q.shape[-1]
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)  # (batch_size, n_heads, seq_len, seq_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = torch.softmax(scores, dim=-1)  # (batch_size, n_heads, seq_len, seq_len)
        attention_weights = self.dropout(attention_weights)  # (batch_size, n_heads, seq_len, seq_len)
        output = torch.matmul(attention_weights, v)  # (batch_size, n_heads, seq_len, head_size)
        return output
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.0, cache=False):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_size = d_model // n_heads
self.use_cache = cache
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.attention = ScaledDotProductAttention(dropout=dropout)
self.cached_k = None
self.cached_v = None
def split_heads(self, x):
"""
x: (batch_size, seq_len, d_model)
"""
batch_size, seq_len, d_model = x.shape
return x.view(batch_size, seq_len, self.n_heads, self.head_size).transpose(1, 2) # (batch_size, n_heads, seq_len, head_size)
def combine_heads(self, x):
"""
x: (batch_size, n_heads, seq_len, head_size)
"""
batch_size, n_heads, seq_len, head_size = x.shape
return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) # (batch_size, seq_len, d_model)
def forward(self, x, mask=None, use_cache=False, past_key_values=None):
batch_size, seq_len, d_model = x.shape
if past_key_values is not None:
self.cached_k, self.cached_v = past_key_values
q = self.W_q(x) # (batch_size, seq_len, d_model)
k = self.W_k(x)
v = self.W_v(x)
q = self.split_heads(q) # (batch_size, n_heads, seq_len, head_size)
k = self.split_heads(k)
v = self.split_heads(v)
# Use KV cache if enabled
if use_cache and self.cached_k is not None and self.cached_v is not None:
# Concatenate current k, v with cached k, v
k = torch.cat([self.cached_k, k], dim=2)
v = torch.cat([self.cached_v, v], dim=2)
self.cached_k = k
self.cached_v = v
# Create causal mask if needed
if mask is None:
# If using cache, adjust mask to account for the full sequence length
full_seq_len = k.size(2)
            # With a KV cache the queries cover only the new tokens while the
            # keys cover the full sequence, so the mask has shape
            # (seq_len, full_seq_len): each new token may attend to every
            # cached position plus the causal prefix of the new positions.
            if use_cache and self.cached_k is not None:
                past_len = full_seq_len - seq_len
                mask = torch.tril(torch.ones(seq_len, full_seq_len, device=x.device), diagonal=past_len)
            else:
                # Standard causal mask over the full sequence
                mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=x.device))
mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
# Use the attention module directly
output = self.attention(q, k, v, mask) # (batch_size, n_heads, seq_len, head_size)
# Combine heads
output = self.combine_heads(output) # (batch_size, seq_len, d_model)
past_key_values = (k, v)
        if use_cache:
            return self.dropout(self.out(output)), past_key_values
        else:
            return self.dropout(self.out(output))
def clear_cache(self):
self.cached_k = None
self.cached_v = None
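As a standalone illustration of cache-aware masking: when a key/value cache is used, the queries cover only the new tokens while the keys cover the full sequence, so a rectangular causal mask is needed. `torch.tril` with a `diagonal` offset equal to the cached length produces it (the lengths here are made up for the example):

```python
import torch

# Made-up lengths: 3 tokens already cached, 2 new tokens being processed
past_len, new_len = 3, 2
full_len = past_len + new_len
# Row t (a new token) may attend to every cached position and to new
# positions up to and including itself
mask = torch.tril(torch.ones(new_len, full_len), diagonal=past_len)
assert mask.tolist() == [[1.0, 1.0, 1.0, 1.0, 0.0],
                         [1.0, 1.0, 1.0, 1.0, 1.0]]
```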
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1, use_cache=False):
super().__init__()
self.masked_mha = MultiHeadAttention(d_model, n_heads, dropout, cache=use_cache)
self.layer_norm1 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
nn.Dropout(dropout)
)
self.layer_norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, use_cache=False, past_key_values=None):
# Self-attention with residual connection and layer normalization
residual = x
x = self.layer_norm1(x) # Pre-norm architecture
        if use_cache:
            # Request the cache whenever use_cache is set, even on the first
            # step when past_key_values is still None, so that the attention
            # layer returns its key/value cache.
            x, past_key_values = self.masked_mha(x, use_cache=True, past_key_values=past_key_values)
        else:
            x = self.masked_mha(x)
x = residual + x # Residual connection
# Feed forward with residual connection and layer normalization
residual = x
x = self.layer_norm2(x) # Pre-norm architecture
x = self.feed_forward(x)
x = residual + x # Residual connection
if use_cache:
            return x, past_key_values
else:
return x
def clear_cache(self):
self.masked_mha.clear_cache()
class iGPT(nn.Module):
def __init__(self, vocab_size, context_length, d_model, n_heads, n_layers, dropout=0.1, use_cache=False):
super().__init__()
self.vocab_size = vocab_size
self.context_length = context_length
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.use_cache = use_cache
# Token embedding
self.token_embedding = nn.Embedding(vocab_size, d_model)
# Positional embedding (learned, as per iGPT specs)
self.position_embedding = nn.Embedding(context_length, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# Stack of decoder layers
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, dropout, use_cache=use_cache)
for _ in range(n_layers)
])
# Final layer norm
self.layer_norm = nn.LayerNorm(d_model)
# Output projection
self.output_projection = nn.Linear(d_model, vocab_size)
    def forward(self, x, past_key_values=None, use_cache=False):
        # x shape: (batch_size, seq_len)
        batch_size, seq_len = x.shape
        device = x.device
        # When generating with a KV cache, x holds only the new tokens, so
        # offset the position indices by the cached sequence length.
        past_len = 0
        if past_key_values is not None and past_key_values[0] is not None:
            past_len = past_key_values[0][0].size(2)
        positions = torch.arange(past_len, past_len + seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
        # Get embeddings
        token_emb = self.token_embedding(x)  # (batch_size, seq_len, d_model)
        pos_emb = self.position_embedding(positions)  # (batch_size, seq_len, d_model)
        # Combine embeddings
        x = token_emb + pos_emb  # (batch_size, seq_len, d_model)
        x = self.dropout(x)
        # Apply decoder layers, threading a separate key/value cache per layer
        # (reusing one cache across layers would mix up the layers' keys).
        new_past_key_values = []
        for idx, layer in enumerate(self.decoder_layers):
            if use_cache:
                layer_past = None if past_key_values is None else past_key_values[idx]
                x, layer_cache = layer(x, use_cache=True, past_key_values=layer_past)
                new_past_key_values.append(layer_cache)
            else:
                x = layer(x)
        # Apply final layer norm
        x = self.layer_norm(x)  # (batch_size, seq_len, d_model)
        # Project to vocabulary
        logits = self.output_projection(x)  # (batch_size, seq_len, vocab_size)
        if use_cache:
            return logits, new_past_key_values
        else:
            return logits
def clear_cache(self):
for layer in self.decoder_layers:
layer.clear_cache()
def test_igpt():
# Define dummy parameters
vocab_size = 10
context_length = 20
d_model = 128
n_heads = 4
n_layers = 2
batch_size = 5
seq_len = context_length
# Create a dummy input tensor
dummy_input = torch.randint(0, vocab_size, (batch_size, seq_len))
# Initialize the iGPT model
model = iGPT(vocab_size, context_length, d_model, n_heads, n_layers)
# Test token embedding
token_emb = model.token_embedding(dummy_input)
print("Token embedding shape:", token_emb.shape)
assert token_emb.shape == (batch_size, seq_len, d_model), "Token embedding shape mismatch!"
# Test position embedding
positions = torch.arange(0, seq_len, dtype=torch.long, device=dummy_input.device).unsqueeze(0).expand(batch_size, -1)
pos_emb = model.position_embedding(positions)
print("Position embedding shape:", pos_emb.shape)
assert pos_emb.shape == (batch_size, seq_len, d_model), "Position embedding shape mismatch!"
# Test each decoder layer
x = token_emb + pos_emb
x = model.dropout(x)
for i, layer in enumerate(model.decoder_layers):
x_before = x.clone()
x = layer(x)
print(f"Decoder layer {i} output shape:", x.shape)
assert x.shape == (batch_size, seq_len, d_model), f"Decoder layer {i} output shape mismatch!"
# Check that the layer actually modified the input
assert not torch.allclose(x, x_before), f"Decoder layer {i} did not modify the input!"
# Test final layer norm
x_before = x.clone()
x = model.layer_norm(x)
print("Layer norm output shape:", x.shape)
assert x.shape == (batch_size, seq_len, d_model), "Layer norm output shape mismatch!"
# Test output projection
logits = model.output_projection(x)
print("Output logits shape:", logits.shape)
assert logits.shape == (batch_size, seq_len, vocab_size), "Output logits shape mismatch!"
# Full forward pass
output = model(dummy_input)
print("Final output shape:", output.shape)
assert output.shape == (batch_size, seq_len, vocab_size), "Final output shape mismatch!"
print("iGPT model test passed! All layers are implemented correctly.")
# Run the test
test_igpt()
Token embedding shape: torch.Size([5, 20, 128]) Position embedding shape: torch.Size([5, 20, 128]) Decoder layer 0 output shape: torch.Size([5, 20, 128]) Decoder layer 1 output shape: torch.Size([5, 20, 128]) Layer norm output shape: torch.Size([5, 20, 128]) Output logits shape: torch.Size([5, 20, 10]) Final output shape: torch.Size([5, 20, 10]) iGPT model test passed! All layers are implemented correctly.
def generate_samples(model, sequence_length, vocab_size, image_shape, device, num_samples=100, use_cache=False, test_mode=False):
"""
Generates samples from the trained model.
Args:
model: The trained iGPT model
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
image_shape: (H, W, C) tuple specifying image dimensions
device: Device to run generation on
num_samples: Number of samples to generate
use_cache: Whether to use caching for faster sampling
test_mode: If True, only generate first 5 samples and fill the rest with blank images
Returns:
Numpy array of generated samples with shape (num_samples, H, W, C)
and a list of generation times
"""
H, W, C = image_shape
model.eval()
samples = []
import time
time_list = []
# Determine how many samples to actually generate
samples_to_generate = 5 if test_mode else num_samples
with torch.no_grad():
for i in range(num_samples):
if test_mode and i >= samples_to_generate:
                # In test mode, fill remaining samples with blank images
                samples.append(np.zeros((H, W, C), dtype=np.uint8))
                time_list.append(0.0)  # No time spent on blank images
                continue
            start_time = time.time()
            # Reset any key/value cache left over from the previous sample
            model.clear_cache()
            # Start with just the <bos> token
            sample = torch.zeros(1, sequence_length, dtype=torch.long, device=device)
            sample[:, 0] = 0  # <bos> token
            # Cache of key-value pairs when use_cache is enabled
            past_key_values = None
            # Autoregressive generation, one token at a time. Use a separate
            # loop variable so we do not shadow the sample index i above.
            for t in range(1, sequence_length):
                if use_cache and past_key_values is not None:
                    # Only process the newest token, reusing cached key-values
                    logits, past_key_values = model(sample[:, t-1:t], past_key_values=past_key_values, use_cache=True)
                    logits = logits[:, -1, :]  # Prediction for the current position
                else:
                    # Process the entire sequence up to the current position
                    if use_cache:
                        logits, past_key_values = model(sample[:, :t], use_cache=True)
                        logits = logits[:, -1, :]  # Prediction for the current position
                    else:
                        logits = model(sample)
                        logits = logits[:, t-1, :]  # Prediction for the current position
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1).squeeze(-1)
                sample[:, t] = next_token
# Convert tokens back to image format (remove <bos> token)
sample_tokens = sample[:, 1:].cpu().numpy().reshape(H, W)
# Convert single tokens back to RGB values if needed
if C == 3:
sample_rgb = np.zeros((H, W, C), dtype=np.uint8)
sample_rgb[:, :, 0] = sample_tokens // 16 # R = token // 16
sample_rgb[:, :, 1] = (sample_tokens % 16) // 4 # G = (token % 16) // 4
sample_rgb[:, :, 2] = sample_tokens % 4 # B = token % 4
samples.append(sample_rgb)
else:
samples.append(sample_tokens.reshape(H, W, C))
end_time = time.time()
time_list.append(end_time - start_time)
return np.array(samples), np.array(time_list)
import math
def create_dataset(data, image_shape, batch_size):
"""
Converts image data to token sequences and creates PyTorch DataLoader.
Args:
data: A (n_samples, H, W, C) uint8 numpy array of images
image_shape: (H, W, C) tuple specifying image dimensions
batch_size: Batch size for DataLoader
Returns:
DataLoader object with tokenized image sequences
"""
H, W, C = image_shape
# Convert RGB pixels to single tokens (4 values per channel = 64 possible values)
# Shape: (n_samples, H, W, C) -> (n_samples, H, W)
if C == 3:
# Convert RGB values to a single token: r*16 + g*4 + b
# Each channel has values in {0,1,2,3}, so we can encode as a single number 0-63
data_tokens = (data[:,:,:,0] * 16 + data[:,:,:,1] * 4 + data[:,:,:,2])
else:
# For grayscale, just use the values directly
data_tokens = data.reshape(-1, H, W)
# Flatten spatial dimensions to create sequences
# Shape: (n_samples, H, W) -> (n_samples, H*W)
data_flat = data_tokens.reshape(-1, H * W)
# Convert to PyTorch tensors
dataset = torch.utils.data.TensorDataset(torch.tensor(data_flat, dtype=torch.long))
# Create data loader
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
def evaluate_model(model, data_loader, sequence_length, vocab_size, device):
"""
Evaluates model performance on a dataset.
Args:
model: The iGPT model
data_loader: DataLoader containing tokenized images
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
device: Device to run evaluation on
Returns:
Average loss (negative log-likelihood) per dimension
"""
model.eval()
total_loss = 0
total_samples = 0
with torch.no_grad():
for (data,) in data_loader:
data = data.to(device) # Shape: (batch_size, sequence_length-1)
batch_size = data.size(0)
# Create input with <bos> token (0) at the beginning
# Shape: (batch_size, sequence_length)
input_seq = torch.zeros(batch_size, sequence_length, dtype=torch.long, device=device)
input_seq[:, 0] = 0 # <bos> token
input_seq[:, 1:] = data # actual image data
# Create targets (the image tokens to predict)
# Shape: (batch_size, sequence_length-1)
targets = data
# Forward pass
# Shape: (batch_size, sequence_length, vocab_size) -> (batch_size, sequence_length-1, vocab_size)
logits = model(input_seq)[:, :-1, :] # Remove last position's prediction
# Compute loss
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='sum')
total_loss += loss.item()
total_samples += batch_size * (sequence_length - 1)
return total_loss / total_samples
def train_igpt(model, train_loader, test_loader, sequence_length, vocab_size,
device, num_epochs, learning_rate):
"""
Trains the iGPT model.
Args:
model: The iGPT model to train
train_loader: DataLoader for training data
test_loader: DataLoader for test data
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
device: Device to train on
num_epochs: Number of training epochs
learning_rate: Initial learning rate
Returns:
train_losses: Array of training losses per minibatch
test_losses: Array of test losses per epoch
"""
# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Learning rate scheduler with warmup and cosine decay
warmup_steps = 1000
total_steps = len(train_loader) * num_epochs
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
else:
decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Initialize arrays to store losses
train_losses = []
test_losses = [evaluate_model(model, test_loader, sequence_length, vocab_size, device)]
# Training loop
for epoch in range(num_epochs):
model.train()
epoch_losses = []
for batch_idx, (data,) in enumerate(train_loader):
data = data.to(device) # Shape: (batch_size, sequence_length-1)
batch_size = data.size(0)
# Create input with <bos> token (0) at the beginning
# Shape: (batch_size, sequence_length)
input_seq = torch.zeros(batch_size, sequence_length, dtype=torch.long, device=device)
input_seq[:, 0] = 0 # <bos> token
input_seq[:, 1:] = data # actual image data
# Create targets (the image tokens to predict)
# Shape: (batch_size, sequence_length-1)
targets = data
# Forward pass
# Shape: (batch_size, sequence_length, vocab_size) -> (batch_size, sequence_length-1, vocab_size)
logits = model(input_seq)[:, :-1, :] # Remove last position's prediction (don't predict <eos>)
# Compute loss
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
# Record loss
train_losses.append(loss.item())
epoch_losses.append(loss.item())
if batch_idx % 50 == 0:
print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
# Evaluate on test set after each epoch
test_loss = evaluate_model(model, test_loader, sequence_length, vocab_size, device)
test_losses.append(test_loss)
print(f"Epoch {epoch+1}/{num_epochs} completed. Test Loss: {test_loss:.4f}")
return np.array(train_losses), np.array(test_losses)
def q3_a(train_data, test_data, image_shape, dset_id):
"""
    train_data: A (n_train, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
    test_data: A (n_test, H, W, 1) uint8 numpy array of binary images with values in {0, 1}
image_shape: (H, W, 1), height, width, and # of channels of the image
dset_id: An identifying number of which dataset is given (1 or 2). Most likely
used to set different hyperparameters for different datasets
Returns
- a (# of training iterations,) numpy array of train_losses evaluated every minibatch
- a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
- a numpy array of size (100, H, W, 1) of samples with values in {0, 1}
"""
# Hyperparameters
batch_size = 64
learning_rate = 1e-3
num_epochs = 15
# Model parameters as recommended in the instructions
d_model = 128
n_heads = 4
n_layers = 2
# Determine sequence length and vocabulary size
H, W, C = image_shape
sequence_length = H * W * C + 1 # +1 for <bos> token
vocab_size = 2 # Binary images with values in {0, 1}
# Create datasets and data loaders
train_loader = create_dataset(train_data, image_shape, batch_size)
test_loader = create_dataset(test_data, image_shape, batch_size)
# Initialize model and move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
# Train the model
train_losses, test_losses = train_igpt(model, train_loader, test_loader,
sequence_length, vocab_size, device,
num_epochs, learning_rate)
    # Save the trained model, then generate samples
    torch.save(model, 'model_no_cache.pth')
    samples, _ = generate_samples(model, sequence_length, vocab_size, image_shape, device)
return train_losses, test_losses, samples
Results¶
Once you've implemented q3_a, execute the cells below to visualize and save your results.
q3ab_save_results(1, 'a', q3_a)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data Epoch 1/15, Batch 0/164, Loss: 0.7565 Epoch 1/15, Batch 50/164, Loss: 0.3444 Epoch 1/15, Batch 100/164, Loss: 0.2132 Epoch 1/15, Batch 150/164, Loss: 0.2036 Epoch 1/15 completed. Test Loss: 0.1965 Epoch 2/15, Batch 0/164, Loss: 0.1971 Epoch 2/15, Batch 50/164, Loss: 0.1753 Epoch 2/15, Batch 100/164, Loss: 0.1454 Epoch 2/15, Batch 150/164, Loss: 0.1234 Epoch 2/15 completed. Test Loss: 0.1108 Epoch 3/15, Batch 0/164, Loss: 0.1324 Epoch 3/15, Batch 50/164, Loss: 0.1107 Epoch 3/15, Batch 100/164, Loss: 0.0993 Epoch 3/15, Batch 150/164, Loss: 0.1016 Epoch 3/15 completed. Test Loss: 0.0928 Epoch 4/15, Batch 0/164, Loss: 0.0973 Epoch 4/15, Batch 50/164, Loss: 0.0934 Epoch 4/15, Batch 100/164, Loss: 0.0873 Epoch 4/15, Batch 150/164, Loss: 0.0844 Epoch 4/15 completed. Test Loss: 0.0796 Epoch 5/15, Batch 0/164, Loss: 0.0899 Epoch 5/15, Batch 50/164, Loss: 0.0852 Epoch 5/15, Batch 100/164, Loss: 0.0808 Epoch 5/15, Batch 150/164, Loss: 0.0717 Epoch 5/15 completed. Test Loss: 0.0698 Epoch 6/15, Batch 0/164, Loss: 0.0789 Epoch 6/15, Batch 50/164, Loss: 0.0682 Epoch 6/15, Batch 100/164, Loss: 0.0645 Epoch 6/15, Batch 150/164, Loss: 0.0703 Epoch 6/15 completed. Test Loss: 0.0632 Epoch 7/15, Batch 0/164, Loss: 0.0690 Epoch 7/15, Batch 50/164, Loss: 0.0636 Epoch 7/15, Batch 100/164, Loss: 0.0620 Epoch 7/15, Batch 150/164, Loss: 0.0666 Epoch 7/15 completed. Test Loss: 0.0578 Epoch 8/15, Batch 0/164, Loss: 0.0619 Epoch 8/15, Batch 50/164, Loss: 0.0614 Epoch 8/15, Batch 100/164, Loss: 0.0648 Epoch 8/15, Batch 150/164, Loss: 0.0571 Epoch 8/15 completed. Test Loss: 0.0538 Epoch 9/15, Batch 0/164, Loss: 0.0603 Epoch 9/15, Batch 50/164, Loss: 0.0560 Epoch 9/15, Batch 100/164, Loss: 0.0531 Epoch 9/15, Batch 150/164, Loss: 0.0582 Epoch 9/15 completed. 
Test Loss: 0.0515 Epoch 10/15, Batch 0/164, Loss: 0.0600 Epoch 10/15, Batch 50/164, Loss: 0.0551 Epoch 10/15, Batch 100/164, Loss: 0.0550 Epoch 10/15, Batch 150/164, Loss: 0.0522 Epoch 10/15 completed. Test Loss: 0.0493 Epoch 11/15, Batch 0/164, Loss: 0.0491 Epoch 11/15, Batch 50/164, Loss: 0.0547 Epoch 11/15, Batch 100/164, Loss: 0.0544 Epoch 11/15, Batch 150/164, Loss: 0.0545 Epoch 11/15 completed. Test Loss: 0.0481 Epoch 12/15, Batch 0/164, Loss: 0.0584 Epoch 12/15, Batch 50/164, Loss: 0.0540 Epoch 12/15, Batch 100/164, Loss: 0.0552 Epoch 12/15, Batch 150/164, Loss: 0.0537 Epoch 12/15 completed. Test Loss: 0.0473 Epoch 13/15, Batch 0/164, Loss: 0.0509 Epoch 13/15, Batch 50/164, Loss: 0.0541 Epoch 13/15, Batch 100/164, Loss: 0.0493 Epoch 13/15, Batch 150/164, Loss: 0.0562 Epoch 13/15 completed. Test Loss: 0.0467 Epoch 14/15, Batch 0/164, Loss: 0.0487 Epoch 14/15, Batch 50/164, Loss: 0.0509 Epoch 14/15, Batch 100/164, Loss: 0.0518 Epoch 14/15, Batch 150/164, Loss: 0.0517 Epoch 14/15 completed. Test Loss: 0.0465 Epoch 15/15, Batch 0/164, Loss: 0.0497 Epoch 15/15, Batch 50/164, Loss: 0.0526 Epoch 15/15, Batch 100/164, Loss: 0.0508 Epoch 15/15, Batch 150/164, Loss: 0.0521 Epoch 15/15 completed. Test Loss: 0.0465 Final Test Loss: 0.0465
samples shape: (100, 20, 20, 1)
q3ab_save_results(2, 'a', q3_a)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data Epoch 1/15, Batch 0/938, Loss: 0.6212 Epoch 1/15, Batch 50/938, Loss: 0.2837 Epoch 1/15, Batch 100/938, Loss: 0.2267 Epoch 1/15, Batch 150/938, Loss: 0.2190 Epoch 1/15, Batch 200/938, Loss: 0.2150 Epoch 1/15, Batch 250/938, Loss: 0.2095 Epoch 1/15, Batch 300/938, Loss: 0.1933 Epoch 1/15, Batch 350/938, Loss: 0.1886 Epoch 1/15, Batch 400/938, Loss: 0.1882 Epoch 1/15, Batch 450/938, Loss: 0.1887 Epoch 1/15, Batch 500/938, Loss: 0.1821 Epoch 1/15, Batch 550/938, Loss: 0.1730 Epoch 1/15, Batch 600/938, Loss: 0.1748 Epoch 1/15, Batch 650/938, Loss: 0.1798 Epoch 1/15, Batch 700/938, Loss: 0.1679 Epoch 1/15, Batch 750/938, Loss: 0.1672 Epoch 1/15, Batch 800/938, Loss: 0.1667 Epoch 1/15, Batch 850/938, Loss: 0.1572 Epoch 1/15, Batch 900/938, Loss: 0.1585 Epoch 1/15 completed. Test Loss: 0.1423 Epoch 2/15, Batch 0/938, Loss: 0.1555 Epoch 2/15, Batch 50/938, Loss: 0.1412 Epoch 2/15, Batch 100/938, Loss: 0.1427 Epoch 2/15, Batch 150/938, Loss: 0.1342 Epoch 2/15, Batch 200/938, Loss: 0.1320 Epoch 2/15, Batch 250/938, Loss: 0.1263 Epoch 2/15, Batch 300/938, Loss: 0.1251 Epoch 2/15, Batch 350/938, Loss: 0.1259 Epoch 2/15, Batch 400/938, Loss: 0.1237 Epoch 2/15, Batch 450/938, Loss: 0.1127 Epoch 2/15, Batch 500/938, Loss: 0.1134 Epoch 2/15, Batch 550/938, Loss: 0.1098 Epoch 2/15, Batch 600/938, Loss: 0.1132 Epoch 2/15, Batch 650/938, Loss: 0.1166 Epoch 2/15, Batch 700/938, Loss: 0.1157 Epoch 2/15, Batch 750/938, Loss: 0.1137 Epoch 2/15, Batch 800/938, Loss: 0.1190 Epoch 2/15, Batch 850/938, Loss: 0.1151 Epoch 2/15, Batch 900/938, Loss: 0.1169 Epoch 2/15 completed. 
Test Loss: 0.1027 Epoch 3/15, Batch 0/938, Loss: 0.1138 Epoch 3/15, Batch 50/938, Loss: 0.1075 Epoch 3/15, Batch 100/938, Loss: 0.1063 Epoch 3/15, Batch 150/938, Loss: 0.1028 Epoch 3/15, Batch 200/938, Loss: 0.1101 Epoch 3/15, Batch 250/938, Loss: 0.1131 Epoch 3/15, Batch 300/938, Loss: 0.1020 Epoch 3/15, Batch 350/938, Loss: 0.1067 Epoch 3/15, Batch 400/938, Loss: 0.1033 Epoch 3/15, Batch 450/938, Loss: 0.1099 Epoch 3/15, Batch 500/938, Loss: 0.1051 Epoch 3/15, Batch 550/938, Loss: 0.1086 Epoch 3/15, Batch 600/938, Loss: 0.1006 Epoch 3/15, Batch 650/938, Loss: 0.1089 Epoch 3/15, Batch 700/938, Loss: 0.1042 Epoch 3/15, Batch 750/938, Loss: 0.1095 Epoch 3/15, Batch 800/938, Loss: 0.1015 Epoch 3/15, Batch 850/938, Loss: 0.0945 Epoch 3/15, Batch 900/938, Loss: 0.1031 Epoch 3/15 completed. Test Loss: 0.0942 Epoch 4/15, Batch 0/938, Loss: 0.1038 Epoch 4/15, Batch 50/938, Loss: 0.0974 Epoch 4/15, Batch 100/938, Loss: 0.0976 Epoch 4/15, Batch 150/938, Loss: 0.1023 Epoch 4/15, Batch 200/938, Loss: 0.1002 Epoch 4/15, Batch 250/938, Loss: 0.0990 Epoch 4/15, Batch 300/938, Loss: 0.1043 Epoch 4/15, Batch 350/938, Loss: 0.1012 Epoch 4/15, Batch 400/938, Loss: 0.1005 Epoch 4/15, Batch 450/938, Loss: 0.1032 Epoch 4/15, Batch 500/938, Loss: 0.0986 Epoch 4/15, Batch 550/938, Loss: 0.0998 Epoch 4/15, Batch 600/938, Loss: 0.0955 Epoch 4/15, Batch 650/938, Loss: 0.0958 Epoch 4/15, Batch 700/938, Loss: 0.0912 Epoch 4/15, Batch 750/938, Loss: 0.0970 Epoch 4/15, Batch 800/938, Loss: 0.0925 Epoch 4/15, Batch 850/938, Loss: 0.0960 Epoch 4/15, Batch 900/938, Loss: 0.0965 Epoch 4/15 completed. 
Test Loss: 0.0903 Epoch 5/15, Batch 0/938, Loss: 0.0977 Epoch 5/15, Batch 50/938, Loss: 0.0946 Epoch 5/15, Batch 100/938, Loss: 0.0898 Epoch 5/15, Batch 150/938, Loss: 0.0943 Epoch 5/15, Batch 200/938, Loss: 0.0954 Epoch 5/15, Batch 250/938, Loss: 0.0882 Epoch 5/15, Batch 300/938, Loss: 0.0968 Epoch 5/15, Batch 350/938, Loss: 0.0936 Epoch 5/15, Batch 400/938, Loss: 0.0939 Epoch 5/15, Batch 450/938, Loss: 0.0911 Epoch 5/15, Batch 500/938, Loss: 0.0911 Epoch 5/15, Batch 550/938, Loss: 0.0944 Epoch 5/15, Batch 600/938, Loss: 0.0946 Epoch 5/15, Batch 650/938, Loss: 0.0917 Epoch 5/15, Batch 700/938, Loss: 0.0940 Epoch 5/15, Batch 750/938, Loss: 0.0956 Epoch 5/15, Batch 800/938, Loss: 0.0923 Epoch 5/15, Batch 850/938, Loss: 0.0938 Epoch 5/15, Batch 900/938, Loss: 0.0873 Epoch 5/15 completed. Test Loss: 0.0880 Epoch 6/15, Batch 0/938, Loss: 0.0955 Epoch 6/15, Batch 50/938, Loss: 0.0928 Epoch 6/15, Batch 100/938, Loss: 0.0969 Epoch 6/15, Batch 150/938, Loss: 0.0914 Epoch 6/15, Batch 200/938, Loss: 0.0913 Epoch 6/15, Batch 250/938, Loss: 0.0922 Epoch 6/15, Batch 300/938, Loss: 0.0900 Epoch 6/15, Batch 350/938, Loss: 0.0939 Epoch 6/15, Batch 400/938, Loss: 0.0883 Epoch 6/15, Batch 450/938, Loss: 0.0998 Epoch 6/15, Batch 500/938, Loss: 0.0911 Epoch 6/15, Batch 550/938, Loss: 0.0918 Epoch 6/15, Batch 600/938, Loss: 0.0913 Epoch 6/15, Batch 650/938, Loss: 0.0941 Epoch 6/15, Batch 700/938, Loss: 0.0936 Epoch 6/15, Batch 750/938, Loss: 0.0923 Epoch 6/15, Batch 800/938, Loss: 0.0869 Epoch 6/15, Batch 850/938, Loss: 0.0854 Epoch 6/15, Batch 900/938, Loss: 0.0890 Epoch 6/15 completed. 
Test Loss: 0.0852 Epoch 7/15, Batch 0/938, Loss: 0.0907 Epoch 7/15, Batch 50/938, Loss: 0.0869 Epoch 7/15, Batch 100/938, Loss: 0.0891 Epoch 7/15, Batch 150/938, Loss: 0.0905 Epoch 7/15, Batch 200/938, Loss: 0.0886 Epoch 7/15, Batch 250/938, Loss: 0.0889 Epoch 7/15, Batch 300/938, Loss: 0.0960 Epoch 7/15, Batch 350/938, Loss: 0.0868 Epoch 7/15, Batch 400/938, Loss: 0.0964 Epoch 7/15, Batch 450/938, Loss: 0.0897 Epoch 7/15, Batch 500/938, Loss: 0.0885 Epoch 7/15, Batch 550/938, Loss: 0.0915 Epoch 7/15, Batch 600/938, Loss: 0.0906 Epoch 7/15, Batch 650/938, Loss: 0.0864 Epoch 7/15, Batch 700/938, Loss: 0.0852 Epoch 7/15, Batch 750/938, Loss: 0.0880 Epoch 7/15, Batch 800/938, Loss: 0.0888 Epoch 7/15, Batch 850/938, Loss: 0.0813 Epoch 7/15, Batch 900/938, Loss: 0.0898 Epoch 7/15 completed. Test Loss: 0.0835 Epoch 8/15, Batch 0/938, Loss: 0.0873 Epoch 8/15, Batch 50/938, Loss: 0.0872 Epoch 8/15, Batch 100/938, Loss: 0.0908 Epoch 8/15, Batch 150/938, Loss: 0.0904 Epoch 8/15, Batch 200/938, Loss: 0.0869 Epoch 8/15, Batch 250/938, Loss: 0.0844 Epoch 8/15, Batch 300/938, Loss: 0.0903 Epoch 8/15, Batch 350/938, Loss: 0.0834 Epoch 8/15, Batch 400/938, Loss: 0.0842 Epoch 8/15, Batch 450/938, Loss: 0.0839 Epoch 8/15, Batch 500/938, Loss: 0.0870 Epoch 8/15, Batch 550/938, Loss: 0.0879 Epoch 8/15, Batch 600/938, Loss: 0.0842 Epoch 8/15, Batch 650/938, Loss: 0.0883 Epoch 8/15, Batch 700/938, Loss: 0.0861 Epoch 8/15, Batch 750/938, Loss: 0.0850 Epoch 8/15, Batch 800/938, Loss: 0.0797 Epoch 8/15, Batch 850/938, Loss: 0.0834 Epoch 8/15, Batch 900/938, Loss: 0.0931 Epoch 8/15 completed. 
Test Loss: 0.0828 Epoch 9/15, Batch 0/938, Loss: 0.0853 Epoch 9/15, Batch 50/938, Loss: 0.0869 Epoch 9/15, Batch 100/938, Loss: 0.0900 Epoch 9/15, Batch 150/938, Loss: 0.0850 Epoch 9/15, Batch 200/938, Loss: 0.0873 Epoch 9/15, Batch 250/938, Loss: 0.0848 Epoch 9/15, Batch 300/938, Loss: 0.0917 Epoch 9/15, Batch 350/938, Loss: 0.0872 Epoch 9/15, Batch 400/938, Loss: 0.0863 Epoch 9/15, Batch 450/938, Loss: 0.0914 Epoch 9/15, Batch 500/938, Loss: 0.0878 Epoch 9/15, Batch 550/938, Loss: 0.0839 Epoch 9/15, Batch 600/938, Loss: 0.0850 Epoch 9/15, Batch 650/938, Loss: 0.0913 Epoch 9/15, Batch 700/938, Loss: 0.0865 Epoch 9/15, Batch 750/938, Loss: 0.0877 Epoch 9/15, Batch 800/938, Loss: 0.0853 Epoch 9/15, Batch 850/938, Loss: 0.0892 Epoch 9/15, Batch 900/938, Loss: 0.0905 Epoch 9/15 completed. Test Loss: 0.0817 Epoch 10/15, Batch 0/938, Loss: 0.0804 Epoch 10/15, Batch 50/938, Loss: 0.0857 Epoch 10/15, Batch 100/938, Loss: 0.0792 Epoch 10/15, Batch 150/938, Loss: 0.0889 Epoch 10/15, Batch 200/938, Loss: 0.0847 Epoch 10/15, Batch 250/938, Loss: 0.0863 Epoch 10/15, Batch 300/938, Loss: 0.0876 Epoch 10/15, Batch 350/938, Loss: 0.0854 Epoch 10/15, Batch 400/938, Loss: 0.0824 Epoch 10/15, Batch 450/938, Loss: 0.0854 Epoch 10/15, Batch 500/938, Loss: 0.0889 Epoch 10/15, Batch 550/938, Loss: 0.0803 Epoch 10/15, Batch 600/938, Loss: 0.0826 Epoch 10/15, Batch 650/938, Loss: 0.0901 Epoch 10/15, Batch 700/938, Loss: 0.0889 Epoch 10/15, Batch 750/938, Loss: 0.0833 Epoch 10/15, Batch 800/938, Loss: 0.0811 Epoch 10/15, Batch 850/938, Loss: 0.0868 Epoch 10/15, Batch 900/938, Loss: 0.0864 Epoch 10/15 completed. 
Test Loss: 0.0806
Epoch 11/15 completed. Test Loss: 0.0800
Epoch 12/15 completed. Test Loss: 0.0795
Epoch 13/15 completed. Test Loss: 0.0795
Epoch 14/15 completed. Test Loss: 0.0794
Epoch 15/15 completed. Test Loss: 0.0793
Final Test Loss: 0.0793
samples shape: (100, 28, 28, 1)
Part (b) iGPT on Colored Shapes and MNIST¶
Now, implement an iGPT that models color. To reduce the length of token sequences, iGPT models each RGB pixel as a single token, which effectively reduces the context length from H*W*C to just H*W. The original iGPT does this through a k-means clustering approach over pixel colors; because each channel of our images can only take on 4 values (2 bits), we can instead represent each pixel exactly with 64 values (6 bits). Convert the dataset into images of tokens and train iGPT on the colored shapes and colored MNIST datasets.
Checkout the iGPT paper for more details: Generative Pretraining from Pixels
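The 2-bits-per-channel packing described above can be sketched as follows (a minimal sketch; the helper names `pixels_to_tokens` and `tokens_to_pixels` are illustrative, not part of the starter code):

```python
import numpy as np

def pixels_to_tokens(images):
    # images: (N, H, W, 3) uint8 with values in {0, 1, 2, 3}.
    # Pack the three 2-bit channels into one 6-bit token in [0, 64).
    r, g, b = images[..., 0], images[..., 1], images[..., 2]
    return (r.astype(np.int64) * 16 + g * 4 + b)  # (N, H, W)

def tokens_to_pixels(tokens):
    # Inverse mapping: unpack each 6-bit token back into three 2-bit channels.
    r = tokens // 16
    g = (tokens // 4) % 4
    b = tokens % 4
    return np.stack([r, g, b], axis=-1).astype(np.uint8)  # (N, H, W, 3)
```

Because the packing is exact (unlike the paper's k-means palette), tokenization here is lossless: `tokens_to_pixels(pixels_to_tokens(x))` recovers `x`.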
Training times and hyperparameter settings should be the same as in part (a), except train for longer (15 epochs).
You will provide these deliverables:
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- 100 samples from the final trained model
def q3_b(train_data, test_data, image_shape, dset_id):
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: (H, W, C), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
        used to set different hyperparameters for different datasets

    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3}
    """
    # Training hyperparameters
    batch_size = 64
    learning_rate = 1e-3
    num_epochs = 15

    # Model parameters as recommended in the instructions
    d_model = 128
    n_heads = 4
    n_layers = 2

    H, W, C = image_shape
    sequence_length = H * W + 1  # +1 for the <bos> token
    vocab_size = 64  # each pixel is represented by 6 bits (one token per RGB pixel)

    # Create datasets and data loaders
    train_loader = create_dataset(train_data, image_shape, batch_size)
    test_loader = create_dataset(test_data, image_shape, batch_size)

    # Initialize the model and move it to the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)

    # Train the model
    train_losses, test_losses = train_igpt(model, train_loader, test_loader,
                                           sequence_length, vocab_size, device,
                                           num_epochs, learning_rate)

    # Save the model so part (c) can reuse it without retraining
    torch.save(model, f'model_colored_no_cache_{dset_id}.pth')

    # Generate samples
    samples, _ = generate_samples(model, sequence_length, vocab_size, image_shape, device)
    return train_losses, test_losses, samples
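For reference, `generate_samples` (a helper provided in the notebook) boils down to an autoregressive sampling loop along these lines (a minimal sketch; `sample_autoregressive` is illustrative, and the choice of `<bos>` token id is an assumption about the token layout):

```python
import torch

@torch.no_grad()
def sample_autoregressive(model, n_samples, sequence_length, device, bos_token=0):
    # Start every sequence with the <bos> token (id 0 assumed here for
    # illustration; in practice you may reserve a dedicated vocabulary slot).
    tokens = torch.full((n_samples, 1), bos_token, dtype=torch.long, device=device)
    for _ in range(sequence_length - 1):   # sequence_length counts the <bos> slot
        logits = model(tokens)[:, -1, :]   # logits for the next token only
        probs = logits.softmax(dim=-1)
        nxt = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, nxt], dim=1)
    return tokens[:, 1:]  # drop <bos>, leaving H*W pixel tokens per sample
```

Each iteration reruns the model over the whole prefix, which is exactly the inefficiency that the K, V caching in part (c) removes.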
Results¶
Once you've implemented q3_b, execute the cells below to visualize and save your results
q3ab_save_results(1, 'b', q3_b)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/15 completed. Test Loss: 1.3999
Epoch 2/15 completed. Test Loss: 0.2962
Epoch 3/15 completed. Test Loss: 0.1610
Epoch 4/15 completed. Test Loss: 0.1260
Epoch 5/15 completed. Test Loss: 0.1067
Epoch 6/15 completed. Test Loss: 0.0975
Epoch 7/15 completed. Test Loss: 0.0933
Epoch 8/15 completed. Test Loss: 0.0880
Epoch 9/15 completed. Test Loss: 0.0850
Epoch 10/15 completed. Test Loss: 0.0824
Epoch 11/15 completed. Test Loss: 0.0815
Epoch 12/15 completed. Test Loss: 0.0794
Epoch 13/15 completed. Test Loss: 0.0783
Epoch 14/15 completed. Test Loss: 0.0779
Epoch 15/15 completed. Test Loss: 0.0778
Final Test Loss: 0.0778
samples shape: (100, 20, 20, 3)
q3ab_save_results(2, 'b', q3_b)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/15 completed. Test Loss: 0.5847
Epoch 2/15 completed. Test Loss: 0.4656
Epoch 3/15 completed. Test Loss: 0.4147
Epoch 4/15 completed. Test Loss: 0.3693
Epoch 5/15 completed. Test Loss: 0.3128
Epoch 6/15 completed. Test Loss: 0.2667
Epoch 7/15 completed. Test Loss: 0.2301
Epoch 8/15 completed. Test Loss: 0.2041
Epoch 9/15 completed. Test Loss: 0.1880
Epoch 10/15 completed. Test Loss: 0.1780
Epoch 11/15 completed. Test Loss: 0.1705
Epoch 12/15 completed. Test Loss: 0.1653
Epoch 13/15 completed. Test Loss: 0.1622
Epoch 14/15 completed. Test Loss: 0.1611
Epoch 15/15 completed. Test Loss: 0.1610
Final Test Loss: 0.1610
samples shape: (100, 28, 28, 3)
Part (c) K, V Caching for Improved Inference¶
You may have noticed that generation from the transformer is quite slow. Part of this is simply due to its autoregressive nature; another part, however, comes from computational inefficiency: at each forward pass of the model, we redo the computation over the entire past sequence. Specifically, we can cache the keys and values at each multi-head attention layer to predict more quickly at each step.
In self-attention, a sequence is processed by generating three vectors for each element in the sequence: a Query (Q), a Key (K), and a Value (V). These vectors are then used to compute attention scores and subsequently the output of the attention layer. With a K, V cache, each decoding step can be represented as:
- For each index $i$, compute $Q_i$, $K_i$, $V_i$ for the current element
- Retrieve $K_{<i}$ and $V_{<i}$ from the cache (where $<i$ denotes all indices before the current one)
- Compute the attention output using $Q_i$, $[K_{<i}, K_i]$, $[V_{<i}, V_i]$
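The three steps above can be sketched as a single-head causal attention layer with an optional cache (a minimal illustrative sketch under assumed names, not the notebook's attention module; a full iGPT would use multiple heads and reset the cache between generations):

```python
import torch
import torch.nn as nn

class CachedSelfAttention(nn.Module):
    """Single-head causal self-attention with an optional K/V cache."""
    def __init__(self, d_model):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.d_model = d_model
        self.cache = None  # (K, V), each of shape (B, T_past, d_model)

    def forward(self, x, use_cache=False):
        # x: (B, T, d_model); during cached decoding T is 1 (the new token)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if use_cache:
            if self.cache is not None:
                k_past, v_past = self.cache
                k = torch.cat([k_past, k], dim=1)  # [K_{<i}, K_i]
                v = torch.cat([v_past, v], dim=1)  # [V_{<i}, V_i]
            self.cache = (k, v)
        att = (q @ k.transpose(-2, -1)) / self.d_model ** 0.5
        # Causal mask: the query at absolute position i attends to keys 0..i;
        # the diagonal offset T_k - T_q accounts for the cached prefix.
        T_q, T_k = q.size(1), k.size(1)
        mask = torch.ones(T_q, T_k, dtype=torch.bool, device=x.device).tril(T_k - T_q)
        att = att.masked_fill(~mask, float('-inf')).softmax(dim=-1)
        return self.proj(att @ v)
```

Feeding one token at a time with `use_cache=True` reproduces the output of a full uncached forward pass, while only computing Q, K, V for the newest position.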
Next, implement caching in your transformer by modifying your self-attention to make inference more efficient. Use caching for inference in future problems for faster generation! (Note that caching is only used during inference.) You will use the same data as in part (b), dataset 2 of this question (colored MNIST). No training is required in this section; feel free to reuse the model you trained in part (b) on dataset 2.
You will provide these deliverables:
- Over the course of inference, measure the time for the forward pass over the total sequence length with and without caching.
- 100 samples from the final trained model using the caching inference pipeline.
def q3_c(train_data, test_data, image_shape, dset_id):
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: (H, W, C), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
        used to set different hyperparameters for different datasets

    Returns
    - a (# sampling steps,) numpy array of time per sampling iteration, without caching
    - a (# sampling steps,) numpy array of time per sampling iteration, with caching
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3} (generated without caching)
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3} (generated with caching)
    """
    # Training hyperparameters
    batch_size = 64
    learning_rate = 1e-3
    num_epochs = 15

    # Transformer architecture parameters
    d_model = 128
    n_heads = 4
    n_layers = 2

    H, W, C = image_shape
    print("image shape: ", image_shape)
    sequence_length = H * W + 1  # +1 for the <bos> token
    vocab_size = 64  # each pixel is represented by 6 bits

    # Create datasets and data loaders
    train_loader = create_dataset(train_data, image_shape, batch_size)
    test_loader = create_dataset(test_data, image_shape, batch_size)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Train the model (alternatively, load the model saved in part (b):
    # model = torch.load(f'model_colored_no_cache_{dset_id}.pth'))
    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    train_losses, test_losses = train_igpt(model, train_loader, test_loader,
                                           sequence_length, vocab_size, device,
                                           num_epochs, learning_rate)

    # Generate samples and measure the per-step time, without and with caching
    samples_no_cache, time_list_no_cache = generate_samples(
        model, sequence_length, vocab_size, image_shape, device,
        use_cache=False, test_mode=False)
    samples_with_cache, time_list_with_cache = generate_samples(
        model, sequence_length, vocab_size, image_shape, device,
        use_cache=True, test_mode=False)

    return time_list_no_cache, time_list_with_cache, samples_no_cache, samples_with_cache
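For the timing deliverable, per-step wall-clock measurement on a GPU needs a synchronize around each clock read, since CUDA kernels launch asynchronously. A sketch of that pattern (`timed_sampling_steps` and `step_fn` are illustrative names; the provided `generate_samples` is assumed to do its own timing internally):

```python
import time
import torch

def timed_sampling_steps(step_fn, n_steps):
    # step_fn(i) runs the i-th sampling/forward step; synchronizing before and
    # after each measurement ensures pending CUDA kernels are included.
    times = []
    for i in range(n_steps):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        step_fn(i)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    return times
```

Without caching, the per-step time should grow with the prefix length; with caching it should stay roughly flat, which is what the plotted timing curves are meant to show.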
Results¶
Once you've implemented q3_c, execute the cells below to visualize and save your results
q3c_save_results(2, q3_c)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
image shape: (28, 28, 3)
Epoch 1/15 completed. Test Loss: 0.5818
Epoch 2/15 completed. Test Loss: 0.4632
Epoch 3/15 completed. Test Loss: 0.4146
Epoch 4/15 completed. Test Loss: 0.3685
Epoch 5/15 completed. Test Loss: 0.3261
Epoch 6/15 completed. Test Loss: 0.2919
Epoch 7/15 completed. Test Loss: 0.2624
Epoch 8/15 completed. Test Loss: 0.2360
Epoch 9/15 completed. Test Loss: 0.2178
Epoch 10/15 completed. Test Loss: 0.2039
Epoch 11/15 completed. Test Loss: 0.1964
Epoch 12/15 completed. Test Loss: 0.1897
Epoch 13/15 completed. Test Loss: 0.1862
Epoch 14/15 completed. Test Loss: 0.1843
Epoch 15/15 completed. Test Loss: 0.1841
samples shape: (100, 28, 28, 3)
samples shape: (100, 28, 28, 3)
Question 4: Causal Transformer: Tokenized Images¶
Image Tokenization with Vector Quantization¶
Part (a) Image Quantization¶
Above, we implemented iGPT, which autoregressively predicts raw pixels. Transformers have quadratic complexity in the sequence length, which prevents this naive approach from scaling well to large images.
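To make the scaling argument concrete, here is a back-of-the-envelope comparison for our 28x28x3 images (assuming one token per color sub-pixel for the raw-pixel approach; the exact tokenization in Question 3 may differ):

```python
# Rough attention-cost comparison for a 28x28x3 Colored MNIST image.
pixel_seq_len = 28 * 28 * 3   # one token per sub-pixel (raw-pixel iGPT)
token_seq_len = 7 * 7         # one token per VQVAE code (this question)

# Self-attention is O(L^2) in the sequence length L, so the rough
# compute ratio between the two approaches is:
cost_ratio = (pixel_seq_len / token_seq_len) ** 2
print(pixel_seq_len, token_seq_len, cost_ratio)  # 2352 49 2304.0
```

A 48x reduction in sequence length translates to roughly a 2300x reduction in attention compute per image.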
The space of natural images contains highly correlated information, which suggests we can learn a reduced representation. VQVAE is a method that does just that: it learns to map images to a more compact, discrete set of tokens. We will cover this method in more detail in future lectures. For now, all you need to know is that we can learn an encoder (and corresponding decoder) that extracts a discrete representation from an image.
If you are curious, check out the VQVAE paper to learn more: https://arxiv.org/abs/1711.00937 (we will cover this in a future lecture though!)
In this part, we provide a pre-trained VQVAE model, which consists of:
- an encoder to tokenize the images
- a decoder to recover the image
- a token vocabulary of size VQVAE_MODEL.n_embeddings
Below is the code for loading the VQ model. Note that the VQVAE encoding process is lossy, so the decoded images will not exactly match the input. Some blurriness in the recovered image is to be expected. The docstrings of the relevant methods you will need from the VQVAE_MODEL are provided below for your convenience.
We will use two colored MNIST datasets in this part. The first is the same dataset used in previous parts. The second has a colored digit on a differently colored background. We will call these datasets Colored MNIST and Colored MNIST v2. Note that a separate VQVAE is trained for each dataset.
You will provide these deliverables:
- Use the provided encoder model to quantize the images then inspect the recovered images by applying the decoder for each of the two datasets
# @property
# def n_embeddings(self) -> int:
# """The size of the token vocabulary"""
#
# def quantize(self, x: np.ndarray) -> np.ndarray:
# """Quantize an image x.
#
# Args:
# x (np.ndarray, dtype=int): Image to quantize. shape=(batch_size, 28, 28, 3). Values in [0, 3].
#
# Returns:
# np.ndarray: Quantized image. shape=(batch_size, 7, 7). Values in [0, n_embeddings]
# """
#
# def decode(self, z_index: np.ndarray) -> np.ndarray:
# """Decode a quantized image.
#
# Args:
# z_index (np.ndarray, dtype=int): Quantized image. shape=(batch_size, 7, 7). Values in [0, n_embeddings].
#
# Returns:
# np.ndarray: Decoded image. shape=(batch_size, 28, 28, 3). Values in [0, 3].
# """
#
def q4_a(images, vqvae):
    """
    images: (B, H, W, C), the images to pass through the encoder and decoder of the vqvae
    vqvae: a vqvae model, trained on the relevant dataset
    Returns
    - a numpy array of size (2, H, W, C) of the decoded image
    """
    print(vqvae.n_embeddings)
    # Encode the images to discrete token grids, then decode them back.
    quantized_images = vqvae.quantize(images)
    print("quantized images:", quantized_images)
    print("quantized_images shape:", quantized_images.shape)
    autoencoded_images = vqvae.decode(quantized_images)
    return autoencoded_images
q4a_save_results(1, q4_a)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.061281282..1.1016651].
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
1024
quantized images: tensor([[[ 161, 198, 121, 218, 645, 171, 272],
[ 264, 191, 110, 193, 844, 334, 440],
[ 935, 730, 386, 1020, 657, 218, 260],
[ 145, 544, 730, 835, 702, 508, 96],
[1014, 134, 722, 906, 738, 697, 811],
[ 884, 268, 165, 94, 952, 821, 346],
[ 228, 647, 429, 722, 982, 872, 582]],
[[ 579, 228, 811, 219, 811, 569, 57],
[ 749, 699, 11, 305, 925, 830, 395],
[ 145, 593, 907, 422, 421, 533, 130],
[ 769, 429, 342, 201, 261, 309, 348],
[ 272, 609, 409, 884, 253, 19, 643],
[ 250, 740, 465, 253, 772, 264, 228],
[ 376, 534, 832, 18, 922, 134, 354]]])
quantized_images shape: torch.Size([2, 7, 7])
samples shape: (4, 28, 28, 3)
q4a_save_results(2, q4_a)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.08431109..1.1520311].
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
1024
quantized images: tensor([[[288, 75, 641, 75, 907, 907, 288],
[402, 964, 265, 636, 425, 427, 402],
[907, 993, 616, 504, 847, 718, 402],
[ 75, 896, 883, 274, 888, 288, 421],
[ 75, 641, 964, 419, 432, 421, 421],
[336, 451, 859, 904, 117, 402, 288],
[117, 694, 330, 336, 402, 288, 421]],
[[334, 779, 334, 226, 637, 779, 242],
[637, 950, 132, 914, 922, 802, 779],
[179, 253, 651, 167, 937, 713, 779],
[779, 675, 231, 132, 179, 939, 253],
[779, 928, 380, 435, 369, 136, 468],
[779, 928, 939, 859, 211, 625, 637],
[779, 334, 309, 435, 242, 468, 637]]])
quantized_images shape: torch.Size([2, 7, 7])
samples shape: (4, 28, 28, 3)
Part (b) Autoregressive Transformer on Colored Shapes and MNIST with Vector Quantization¶
We can use the VQVAE to tokenize an image dataset. This will result in a much smaller sequence length than the approach we tried in Question 3(b). For this part, train a transformer on the dataset tokenized by the VQVAE.
This is a simplified version of the approach used in VQGAN (Section 3.2: Learning the Composition of Images with Transformers). Again, we will cover this in more detail in a future lecture!
Update the following hyperparameters:
- layers: 4 (we can train a bigger transformer now since less memory is used per input!)
- 30 epochs
You will provide these deliverables:
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- 100 samples from the final trained model
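One detail worth pinning down is how nats/dim is computed for tokenized images: cross-entropy over the predicted tokens, averaged per token, is already in nats. A minimal sketch with illustrative shapes and names (not taken from the starter code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, seq_len, vocab = 4, 49, 1024            # 7x7 token grid, VQVAE vocabulary
logits = torch.randn(B, seq_len, vocab)    # stand-in for transformer outputs
targets = torch.randint(0, vocab, (B, seq_len))

# F.cross_entropy uses the natural log and averages over all B * 49
# positions, so its value is directly the NLL in nats per token (dim).
nats_per_dim = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))

# Equivalent manual computation from log-softmax:
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1).mean()
assert torch.allclose(nats_per_dim, manual)
print(float(nats_per_dim))
```

Since the token grid is the model's input space here, "dim" means one of the 49 tokens; no conversion from pixel dimensions is needed.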
def create_tokenized_data(data, image_shape, batch_size, vqvae):
    H, W, C = image_shape
    data_tokens = vqvae.quantize(data)                               # (N, 7, 7) token indices
    data_flat = np.reshape(data_tokens, (data_tokens.shape[0], -1))  # (N, 49)
    dataset = torch.utils.data.TensorDataset(data_flat)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
def q4_b(train_data, test_data, image_shape, dset_id, vqvae):
    """
    train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
    image_shape: (H, W, C), height, width, and # of channels of the image
    dset_id: An identifying number of which dataset is given (1 or 2). Most likely
      used to set different hyperparameters for different datasets
    vqvae: a vqvae model, trained on dataset dset_id
    Returns
    - a (# of training iterations,) numpy array of train_losses evaluated every minibatch
    - a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a numpy array of size (100, H, W, C) of samples with values in {0, 1, 2, 3}
    """
    H, W, C = image_shape

    # Hyperparameters
    batch_size = 128
    learning_rate = 1e-3
    num_epochs = 30
    d_model = 128
    n_heads = 4
    n_layers = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Sequence length and vocabulary size
    sequence_length = 7 * 7 + 1  # +1 for the <bos> token
    vocab_size = vqvae.n_embeddings

    # Tokenized dataloaders
    train_loader = create_tokenized_data(train_data, image_shape, batch_size, vqvae)
    test_loader = create_tokenized_data(test_data, image_shape, batch_size, vqvae)

    model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
    train_losses, test_losses = train_igpt(model, train_loader, test_loader,
                                           sequence_length, vocab_size, device,
                                           num_epochs, learning_rate)

    # Sample token grids, then decode them back to images with the VQVAE
    token_image_shape = (7, 7, 1)
    samples, _ = generate_samples(model, sequence_length, vocab_size, token_image_shape, device)
    print("samples shape:", samples.shape)
    samples = samples.squeeze(-1)
    print("samples shape:", samples.shape)
    samples = vqvae.decode(samples)
    return train_losses, test_losses, samples
Results¶
Once you've implemented q4_b, execute the cells below to visualize and save your results
q4b_save_results(1, q4_b)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
(per-batch training output for epochs 1–30 omitted; test loss decreased steadily from 4.8370 after epoch 1 to 3.1144 after epoch 30)
samples shape: (100, 7, 7, 1)
samples shape: (100, 7, 7)
Final Test Loss: 3.1144
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.13376339..1.283002].
samples shape: (100, 28, 28, 3)
q4b_save_results(2, q4_b)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
(per-batch training output omitted; test loss decreased steadily from 4.0851 after epoch 1 to 2.9521 after epoch 27; the log ends during epoch 28)
Test Loss: 2.9519 Epoch 29/30, Batch 0/469, Loss: 3.0644 Epoch 29/30, Batch 50/469, Loss: 2.9810 Epoch 29/30, Batch 100/469, Loss: 2.9787 Epoch 29/30, Batch 150/469, Loss: 2.9147 Epoch 29/30, Batch 200/469, Loss: 3.0088 Epoch 29/30, Batch 250/469, Loss: 3.0743 Epoch 29/30, Batch 300/469, Loss: 3.0240 Epoch 29/30, Batch 350/469, Loss: 3.0339 Epoch 29/30, Batch 400/469, Loss: 3.0308 Epoch 29/30, Batch 450/469, Loss: 2.9863 Epoch 29/30 completed. Test Loss: 2.9516 Epoch 30/30, Batch 0/469, Loss: 3.0039 Epoch 30/30, Batch 50/469, Loss: 2.9561 Epoch 30/30, Batch 100/469, Loss: 3.0510 Epoch 30/30, Batch 150/469, Loss: 2.9257 Epoch 30/30, Batch 200/469, Loss: 2.9440 Epoch 30/30, Batch 250/469, Loss: 2.9676 Epoch 30/30, Batch 300/469, Loss: 3.1028 Epoch 30/30, Batch 350/469, Loss: 2.9682 Epoch 30/30, Batch 400/469, Loss: 2.9149 Epoch 30/30, Batch 450/469, Loss: 2.9851 Epoch 30/30 completed. Test Loss: 2.9514 samples shape: (100, 7, 7, 1) samples shape: (100, 7, 7) Final Test Loss: 2.9514
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.23252869..1.2633749].
samples shape: (100, 28, 28, 3)
Question 5: Causal Transformer: Text¶
Now let's consider text! You are probably already familiar with autoregressive transformers for text, now more commonly known as Large Language Models (LLMs). We will now implement a simplified version.
We will be dealing with a small poetry dataset. See some of the data below.
data = visualize_q5_data()
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Sample 1
E.E. Cummings, [as freedom is a breakfastfood] from Complete Poems 1904-1962, edited by George J. Firmage. Copyright 1926, 1954, 1991 by the Trustees for the E.E. Cummings Trust. Copyright 1985 by George James Firmage. Reprinted with the permission of Liveright Publishing Corporation.
--------------------------------------------------------------------------------
Sample 2
The moon has left the sky, love,
The stars are hiding now,
And frowning on the world, love,
Night bares her sable brow.
The snow is on the ground, love,
And cold and keen the air is.
Im singing here to you, love;
Youre dreaming there in Paris.
But this is Natures law, love,
Though just it may not seem,
That men should wake to sing, love;
While maidens sleep and dream.
Them care may not molest, love,
Nor stir them from their slumbers,
Though midnight find the swain, love.
Still halting oer his numbers.
I watch the rosy dawn, love,
Come stealing up the east,
While all things round rejoice, love,
That Night her reign has ceased.
The lark will soon be heard, love,
And on his way be winging;
When Natures poets, wake, love,
Why should a man be singing?
--------------------------------------------------------------------------------
Sample 3
My sweetest Lesbia, let us live and love,
And though the sager sort our deeds reprove,
Let us not weigh them. Heavens great lamps do dive
Into their west, and straight again revive,
But soon as once set is our little light,
Then must we sleep one ever-during night.
If all would lead their lives in love like me,
Then bloody swords and armor should not be;
No drum nor trumpet peaceful sleeps should move,
Unless alarm came from the camp of love.
But fools do live, and waste their little light,
And seek with pain their ever-during night.
When timely death my life and fortune ends,
Let not my hearse be vexed with mourning friends,
But let all lovers, rich in triumph, come
And with sweet pastimes grace my happy tomb;
And Lesbia, close up thou my little light,
And crown with love my ever-during night.
--------------------------------------------------------------------------------
Sample 4
When, in disgrace with fortune and mens eyes,
I all alone beweep my outcast state,
And trouble deaf heaven with my bootless cries,
And look upon myself and curse my fate,
Wishing me like to one more rich in hope,
Featured like him, like him with friends possessed,
Desiring this mans art and that mans scope,
With what I most enjoy contented least;
Yet in these thoughts myself almost despising,
Haply I think on thee, and then my state,
(Like to the lark at break of day arising
From sullen earth) sings hymns at heavens gate;
For thy sweet love remembered such wealth brings
That then I scorn to change my state with kings.
--------------------------------------------------------------------------------
Part (a) Modeling Text¶
Train a transformer on the poetry dataset.
Data Preprocessing:
- We will use a simple method to tokenize the data: convert each unique character into a token. (Current LLMs use more sophisticated tokenizers, most commonly byte-pair encoding.)
- Previously we used a <bos> token as part of the model, just like iGPT. For text, we may not always sample a sequence that starts at the beginning of a document. Instead, we will add the <bos> token to the beginning of every sequence in the dataset and remove the <bos> token from the model.
- Another problem is that the model must know when to stop sampling. This is handled by appending an <eos>, or end of sequence, token to the end of every sequence in the dataset.
- We can now convert the sequences into subsequences of size context_length for training!
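As an illustration, the preprocessing steps above can be sketched on a toy string. The token ids here are made up for the example; the actual mapping is built by the Tokenizer below.

```python
# Character-level tokenization with <bos>/<eos>, then sliding-window
# chunking into fixed-size training subsequences.
BOS, EOS = 0, 1  # reserved special-token ids

def tokenize(text, char_to_id):
    # prepend <bos> and append <eos> to every dataset sequence
    return [BOS] + [char_to_id[c] for c in text] + [EOS]

def chunk(tokens, context_length, stride=1):
    # slide a window of size context_length over the token sequence
    return [tokens[i:i + context_length]
            for i in range(0, len(tokens) - context_length + 1, stride)]

chars = sorted(set("abba"))                           # ['a', 'b']
char_to_id = {c: i + 2 for i, c in enumerate(chars)}  # ids 0/1 are reserved
tokens = tokenize("abba", char_to_id)                 # [0, 2, 3, 3, 2, 1]
subseqs = chunk(tokens, context_length=4)
# [[0, 2, 3, 3], [2, 3, 3, 2], [3, 3, 2, 1]]
```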
We recommend the following hyperparameters:
- Sequence length: 128
- 5 epochs
You will provide these deliverables:
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- Provide 5 unconditional samples of 128 characters showcasing the model's text generation capabilities (text samples should stop after <eos>; text after <eos> can be removed in post-processing)
import torch
from torch.utils.data import Dataset
class Tokenizer:
def __init__(self, texts):
# Create a set of all unique characters across all texts
all_chars = set()
for text in texts:
all_chars.update(text)
# Sort characters for consistent mapping
all_chars = sorted(all_chars)
self.char_to_id = {char: i + 2 for i, char in enumerate(all_chars)}
self.id_to_char = {i + 2: char for i, char in enumerate(all_chars)}
self.bos_token = 0
self.eos_token = 1
self.char_to_id['<bos>'] = self.bos_token
self.char_to_id['<eos>'] = self.eos_token
self.id_to_char[self.bos_token] = '<bos>'
self.id_to_char[self.eos_token] = '<eos>'
self.vocab_size = len(self.char_to_id)
def encode(self, text):
tokens = [self.char_to_id[char] for char in text]
tokens.insert(0, self.bos_token)
tokens.append(self.eos_token)
return torch.tensor(tokens)
def decode(self, tokens):
chars = [self.id_to_char[token] for token in tokens if token != self.bos_token and token != self.eos_token]
# remove the special tokens
chars = [char for char in chars if char != '<bos>' and char != '<eos>']
return ''.join(chars)
class TextData(Dataset):
def __init__(self, texts, tokenizer, sequence_length):
self.tokenizer = tokenizer
self.sequence_length = sequence_length
# Tokenize all texts with BOS and EOS tokens
self.sequences = []
for text in texts:
# Encode the text (this adds BOS and EOS)
tokens = tokenizer.encode(text)
stride = 1
if len(tokens) > sequence_length:
for i in range(0, len(tokens) - sequence_length + 1, stride):
self.sequences.append(tokens[i:i + sequence_length])
# else:
# # Drop the shorter sequences
# padded = torch.full((sequence_length,), self.tokenizer.eos_token, dtype=tokens.dtype)
# padded[:len(tokens)] = tokens
# self.sequences.append(padded)
def __len__(self):
return len(self.sequences)
def __getitem__(self, index):
return self.sequences[index]
from torch.utils.data import Dataset, DataLoader
def create_text_data_loader(texts, sequence_length, tokenizer, batch_size):
dataset = TextData(texts, tokenizer, sequence_length)
return DataLoader(dataset, batch_size=batch_size, shuffle=True)
# text model architecture version
class ScaledDotProductAttention(nn.Module):
def __init__(self, dropout=0.0):
super().__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
"""
q: (batch_size, n_heads, seq_len, head_size)
k: (batch_size, n_heads, seq_len, head_size)
v: (batch_size, n_heads, seq_len, head_size)
"""
d_k = q.shape[-1]
scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k) # (batch_size, n_heads, seq_len, seq_len)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = torch.softmax(scores, dim=-1) # (batch_size, n_heads, seq_len, seq_len)
attention_weights = self.dropout(attention_weights) # (batch_size, n_heads, seq_len, seq_len)
output = torch.matmul(attention_weights, v) # (batch_size, n_heads, seq_len, head_size)
return output
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.0, cache=False):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_size = d_model // n_heads
self.use_cache = cache
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.attention = ScaledDotProductAttention(dropout=dropout)
self.cached_k = None
self.cached_v = None
def split_heads(self, x):
"""
x: (batch_size, seq_len, d_model)
"""
batch_size, seq_len, d_model = x.shape
return x.view(batch_size, seq_len, self.n_heads, self.head_size).transpose(1, 2) # (batch_size, n_heads, seq_len, head_size)
def combine_heads(self, x):
"""
x: (batch_size, n_heads, seq_len, head_size)
"""
batch_size, n_heads, seq_len, head_size = x.shape
return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) # (batch_size, seq_len, d_model)
def forward(self, x, mask=None, use_cache=False, past_key_values=None):
batch_size, seq_len, d_model = x.shape
        if past_key_values is not None:
            self.cached_k, self.cached_v = past_key_values
        elif use_cache:
            # Starting a fresh cached generation: drop any stale cache
            # left over from a previous sample
            self.cached_k, self.cached_v = None, None
q = self.W_q(x) # (batch_size, seq_len, d_model)
k = self.W_k(x)
v = self.W_v(x)
q = self.split_heads(q) # (batch_size, n_heads, seq_len, head_size)
k = self.split_heads(k)
v = self.split_heads(v)
# Use KV cache if enabled
if use_cache and self.cached_k is not None and self.cached_v is not None:
# Concatenate current k, v with cached k, v
k = torch.cat([self.cached_k, k], dim=2)
v = torch.cat([self.cached_v, v], dim=2)
self.cached_k = k
self.cached_v = v
# Create causal mask if needed
if mask is None:
# If using cache, adjust mask to account for the full sequence length
full_seq_len = k.size(2)
            # For the cached version, cached tokens come first in k/v, so each
            # of the seq_len new tokens may attend to every cached token plus
            # the causal prefix of the current block
            if use_cache and self.cached_k is not None:
                past_len = full_seq_len - seq_len
                mask = torch.tril(torch.ones(seq_len, full_seq_len), diagonal=past_len).to(x.device)
else:
# Standard causal mask for the full sequence
mask = torch.tril(torch.ones(full_seq_len, full_seq_len)).to(x.device)
mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
# Use the attention module directly
output = self.attention(q, k, v, mask) # (batch_size, n_heads, seq_len, head_size)
# Combine heads
output = self.combine_heads(output) # (batch_size, seq_len, d_model)
past_key_values = (k, v)
if use_cache:
return self.dropout(self.out(output)) , past_key_values
else:
return self.dropout(self.out(output))
def clear_cache(self):
self.cached_k = None
self.cached_v = None
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1, use_cache=False):
super().__init__()
self.masked_mha = MultiHeadAttention(d_model, n_heads, dropout, cache=use_cache)
self.layer_norm1 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
nn.Dropout(dropout)
)
self.layer_norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, use_cache=False, past_key_values=None):
# Self-attention with residual connection and layer normalization
residual = x
x = self.layer_norm1(x) # Pre-norm architecture
        if use_cache:
            x, past_key_values = self.masked_mha(x, use_cache=True, past_key_values=past_key_values)
        else:
            x = self.masked_mha(x)
x = residual + x # Residual connection
# Feed forward with residual connection and layer normalization
residual = x
x = self.layer_norm2(x) # Pre-norm architecture
x = self.feed_forward(x)
x = residual + x # Residual connection
if use_cache:
return x , past_key_values
else:
return x
def clear_cache(self):
self.masked_mha.clear_cache()
class iGPT(nn.Module):
def __init__(self, vocab_size, context_length, d_model, n_heads, n_layers, dropout=0.1, use_cache=False):
super().__init__()
self.vocab_size = vocab_size
self.context_length = context_length
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.dropout = dropout
self.use_cache = use_cache
# Token embedding
self.token_embedding = nn.Embedding(vocab_size, d_model)
# Positional embedding (learned, as per iGPT specs)
self.position_embedding = nn.Embedding(context_length, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# Stack of decoder layers
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, dropout, use_cache=use_cache)
for _ in range(n_layers)
])
# Final layer norm
self.layer_norm = nn.LayerNorm(d_model)
# Output projection
self.output_projection = nn.Linear(d_model, vocab_size)
def forward(self, x, past_key_values=None, use_cache=False):
# x shape: (batch_size, seq_len)
batch_size, seq_len = x.shape
device = x.device
# Create position indices
positions = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
# Get embeddings
token_emb = self.token_embedding(x) # (batch_size, seq_len, d_model)
pos_emb = self.position_embedding(positions) # (batch_size, seq_len, d_model)
# Combine embeddings
x = token_emb + pos_emb # (batch_size, seq_len, d_model)
x = self.dropout(x)
        # Apply decoder layers, keeping a separate KV cache per layer
        layer_caches = past_key_values if past_key_values is not None else [None] * self.n_layers
        new_caches = []
        for layer, layer_past in zip(self.decoder_layers, layer_caches):
            if use_cache:
                x, layer_past = layer(x, use_cache=True, past_key_values=layer_past)
                new_caches.append(layer_past)
            else:
                x = layer(x)
        past_key_values = new_caches
# Apply final layer norm
x = self.layer_norm(x) # (batch_size, seq_len, d_model)
# Project to vocabulary
logits = self.output_projection(x) # (batch_size, seq_len, vocab_size)
if use_cache:
return logits, past_key_values
else:
return logits
def clear_cache(self):
for layer in self.decoder_layers:
layer.clear_cache()
import math
def create_dataset(data, image_shape, batch_size):
"""
Converts image data to token sequences and creates PyTorch DataLoader.
Args:
data: A (n_samples, H, W, C) uint8 numpy array of images
image_shape: (H, W, C) tuple specifying image dimensions
batch_size: Batch size for DataLoader
Returns:
DataLoader object with tokenized image sequences
"""
H, W, C = image_shape
# Convert RGB pixels to single tokens (4 values per channel = 64 possible values)
# Shape: (n_samples, H, W, C) -> (n_samples, H, W)
if C == 3:
# Convert RGB values to a single token: r*16 + g*4 + b
# Each channel has values in {0,1,2,3}, so we can encode as a single number 0-63
data_tokens = (data[:,:,:,0] * 16 + data[:,:,:,1] * 4 + data[:,:,:,2])
else:
# For grayscale, just use the values directly
data_tokens = data.reshape(-1, H, W)
# Flatten spatial dimensions to create sequences
# Shape: (n_samples, H, W) -> (n_samples, H*W)
data_flat = data_tokens.reshape(-1, H * W)
# Convert to PyTorch tensors
dataset = torch.utils.data.TensorDataset(torch.tensor(data_flat, dtype=torch.long))
# Create data loader
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
def evaluate_model(model, data_loader, sequence_length, vocab_size, device):
"""
Evaluates model performance on a dataset.
Args:
model: The iGPT model
data_loader: DataLoader containing tokenized images
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
device: Device to run evaluation on
Returns:
Average loss (negative log-likelihood) per dimension
"""
model.eval()
total_loss = 0
total_samples = 0
with torch.no_grad():
for data in data_loader:
            data = data.to(device)  # Shape: (batch_size, sequence_length)
            batch_size = data.size(0)
            input_seq = data[:, :-1]
            targets = data[:, 1:]
            # Forward pass
            logits = model(input_seq)  # (batch_size, sequence_length-1, vocab_size)
# Compute loss
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='sum')
total_loss += loss.item()
total_samples += batch_size * (sequence_length - 1)
return total_loss / total_samples
def train_igpt(model, train_loader, test_loader, sequence_length, vocab_size,
device, num_epochs, learning_rate):
"""
Trains the iGPT model.
Args:
model: The iGPT model to train
train_loader: DataLoader for training data
test_loader: DataLoader for test data
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
device: Device to train on
num_epochs: Number of training epochs
learning_rate: Initial learning rate
Returns:
train_losses: Array of training losses per minibatch
test_losses: Array of test losses per epoch
"""
# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Learning rate scheduler with warmup and cosine decay
warmup_steps = 100
total_steps = len(train_loader) * num_epochs
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
else:
decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Initialize arrays to store losses
train_losses = []
test_losses = [evaluate_model(model, test_loader, sequence_length, vocab_size, device)]
# Training loop
for epoch in range(num_epochs):
model.train()
epoch_losses = []
batch_idx = 0
for data in train_loader:
batch_idx += 1
data = data.to(device) # Shape: (batch_size, sequence_length)
# Shape: (batch_size, sequence_length-1)
input_seq = data[:, :-1]
targets = data[:, 1:]
# Forward pass
logits = model(input_seq)
targets = targets.to(device)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
# print(f"loss: {loss.item():.4f}")
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
# Record loss
train_losses.append(loss.item())
if batch_idx % 50 == 0:
print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
# Evaluate on test set after each epoch
test_loss = evaluate_model(model, test_loader, sequence_length, vocab_size, device)
test_losses.append(test_loss)
print(f"Epoch {epoch+1}/{num_epochs} completed. Test Loss: {test_loss:.4f}")
return np.array(train_losses), np.array(test_losses)
def generate_text_samples(model, tokenizer, max_length, device, num_samples=10, temperature=1.0, use_cache=False):
"""
Generates text samples from the trained model.
Args:
model: The trained language model
tokenizer: The tokenizer used to encode/decode text
max_length: Maximum length of the generated sequence (including BOS/EOS)
device: Device to run generation on
num_samples: Number of samples to generate
temperature: Controls randomness (lower = more deterministic)
use_cache: Whether to use caching for faster sampling
Returns:
List of generated text samples and a list of generation times
"""
model.eval()
samples = []
import time
time_list = []
with torch.no_grad():
for _ in range(num_samples):
start_time = time.time()
# Start with just the BOS token
current_seq = torch.tensor([[tokenizer.bos_token]], dtype=torch.long, device=device)
# Cache for key-value pairs if using caching
past_key_values = None
# Autoregressive generation - one token at a time
for _ in range(max_length - 1): # -1 because we already have BOS
if use_cache and past_key_values is not None:
# Only need to process the new token with cached key-values
logits, past_key_values = model(
current_seq[:, -1:],
past_key_values=past_key_values,
use_cache=True
)
logits = logits[:, -1, :] # Get prediction for current position
else:
# Process the entire sequence
if use_cache:
logits, past_key_values = model(current_seq, use_cache=True)
logits = logits[:, -1, :] # Get prediction for current position
else:
logits = model(current_seq)
logits = logits[:, -1, :] # Get prediction for last position
# Apply temperature
if temperature != 1.0:
logits = logits / temperature
# Sample from the probability distribution
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, 1)
# Append new token to sequence
current_seq = torch.cat([current_seq, next_token], dim=1)
# Stop if EOS token is generated
if next_token.item() == tokenizer.eos_token:
break
# Decode the generated sequence
generated_tokens = current_seq[0].cpu().tolist()
generated_text = tokenizer.decode(generated_tokens)
samples.append(generated_text)
end_time = time.time()
time_list.append(end_time - start_time)
return samples, np.array(time_list)
import torch.utils.data as data
def q5_a(train_text, test_text):
"""
train_text: list[str] Train text sequences.
test_text: list[str] Test text sequences.
Returns
- a (# of training iterations,) numpy array of train_losses evaluated every minibatch
- a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
- a list of 5 (str), 5 generated samples from the model.
"""
sequence_length = 128
epochs = 5
learning_rate = 1e-3
d_model = 128
n_heads = 4
n_layers = 6
batch_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = Tokenizer(train_text)
vocab_size = tokenizer.vocab_size
train_loader = create_text_data_loader(train_text, sequence_length, tokenizer, batch_size)
test_loader = create_text_data_loader(test_text, sequence_length, tokenizer, batch_size)
model = iGPT(vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
train_losses, test_losses = train_igpt(model, train_loader, test_loader, sequence_length, vocab_size, device, epochs, learning_rate)
text_samples, _ = generate_text_samples(model, tokenizer, sequence_length, device, num_samples=5)
return train_losses, test_losses, text_samples
Results¶
Once you've implemented q5_a, execute the cells below to visualize and save your results
q5a_save_results(q5_a)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Epoch 1/5, Batch 50/477, Loss: 2.8145 Epoch 1/5, Batch 100/477, Loss: 2.4764 Epoch 1/5, Batch 150/477, Loss: 2.3922 Epoch 1/5, Batch 200/477, Loss: 2.3275 Epoch 1/5, Batch 250/477, Loss: 2.2068 Epoch 1/5, Batch 300/477, Loss: 2.0961 Epoch 1/5, Batch 350/477, Loss: 2.0128 Epoch 1/5, Batch 400/477, Loss: 1.9417 Epoch 1/5, Batch 450/477, Loss: 1.8972 Epoch 1/5 completed. Test Loss: 1.8999 Epoch 2/5, Batch 50/477, Loss: 1.8257 Epoch 2/5, Batch 100/477, Loss: 1.7714 Epoch 2/5, Batch 150/477, Loss: 1.7343 Epoch 2/5, Batch 200/477, Loss: 1.7065 Epoch 2/5, Batch 250/477, Loss: 1.6745 Epoch 2/5, Batch 300/477, Loss: 1.6477 Epoch 2/5, Batch 350/477, Loss: 1.6189 Epoch 2/5, Batch 400/477, Loss: 1.5987 Epoch 2/5, Batch 450/477, Loss: 1.5895 Epoch 2/5 completed. Test Loss: 1.6667 Epoch 3/5, Batch 50/477, Loss: 1.5469 Epoch 3/5, Batch 100/477, Loss: 1.5355 Epoch 3/5, Batch 150/477, Loss: 1.5236 Epoch 3/5, Batch 200/477, Loss: 1.5015 Epoch 3/5, Batch 250/477, Loss: 1.5034 Epoch 3/5, Batch 300/477, Loss: 1.4954 Epoch 3/5, Batch 350/477, Loss: 1.4756 Epoch 3/5, Batch 400/477, Loss: 1.4742 Epoch 3/5, Batch 450/477, Loss: 1.4636 Epoch 3/5 completed. Test Loss: 1.6080 Epoch 4/5, Batch 50/477, Loss: 1.4474 Epoch 4/5, Batch 100/477, Loss: 1.4380 Epoch 4/5, Batch 150/477, Loss: 1.4350 Epoch 4/5, Batch 200/477, Loss: 1.4301 Epoch 4/5, Batch 250/477, Loss: 1.4225 Epoch 4/5, Batch 300/477, Loss: 1.4228 Epoch 4/5, Batch 350/477, Loss: 1.4133 Epoch 4/5, Batch 400/477, Loss: 1.4063 Epoch 4/5, Batch 450/477, Loss: 1.3998 Epoch 4/5 completed. Test Loss: 1.6020 Epoch 5/5, Batch 50/477, Loss: 1.3968 Epoch 5/5, Batch 100/477, Loss: 1.4050 Epoch 5/5, Batch 150/477, Loss: 1.3925 Epoch 5/5, Batch 200/477, Loss: 1.3913 Epoch 5/5, Batch 250/477, Loss: 1.4003 Epoch 5/5, Batch 300/477, Loss: 1.3889 Epoch 5/5, Batch 350/477, Loss: 1.3907 Epoch 5/5, Batch 400/477, Loss: 1.3959 Epoch 5/5, Batch 450/477, Loss: 1.4015 Epoch 5/5 completed. Test Loss: 1.6017 Final Test Loss: 1.6017
Sample 1
Doe gods "Hen the Spype and "yde with eyes
Not little kins hise outbing through the bair,
And with your bronze younge tongue
Sample 2
Those dress which forth did golden she
Of heaving in white do:
There she's sprungth,
And thee but them burn.
Sleeps on it
Sample 3
Nymphings love's the long Time, and leaves store;
where your stouth with the floor story-flocks
To still was forth
Sample 4
He heaven thou be small at like a certain,
With the risely plant that wait
With they death may the Chokall straight-19Kobe,
Sample 5
A from Poetrynel's beast
And so song: from grown and his grace,
Free, which making the eyes content,
Leaves of right, what b
Question 6: Causal Transformer: Multimodal¶
So far, we have been dealing only with autoregressive generation of a single modality. Now we will train a model that operates on multiple modalities!
We will use the Colored MNIST v2 dataset, which pairs each colored MNIST image with a text description. Run the cell below to visualize the data along with the text annotations.
visualize_q6_data()
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data
Part (a) Multimodal Text and Image Generation¶
Implement and train an autoregressive (AR) model capable of handling both text and image data. The model should be designed to process sequences composed of concatenated text and image tokens in both orders (text followed by images and images followed by text). Additionally, the model should be capable of generating unconditional text and image samples.
Data Preprocessing:
- Text Tokens: Map each unique word in the text data to a unique token. (Note that all text descriptions contain the exact same number of words. This simplifies text processing, as you won't have to deal with sequences of different lengths as in Question 5.)
- Image Tokens: Quantize the image data into tokens using the VQVAE tokenizer from Problem 4.
- In this problem, we have 2 modalities. Introduce an <end of text> token and an <end of image> token. After seeing such a token, the model should switch to sampling the next modality.
- Formulate batches as sequences of concat([<end of image>, text_tokens, <end of text>, image_tokens]) and concat([<end of text>, image_tokens, <end of image>, text_tokens]), with a 50/50 split between the two orderings.
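A minimal sketch of this batch formulation, with made-up token ids (the real ids come from your word tokenizer and the VQ-VAE codebook):

```python
import numpy as np

EOT, EOI = 0, 1  # <end of text>, <end of image> (illustrative ids)

def make_sequence(text_tokens, image_tokens, rng):
    # 50/50 split between the two modality orderings
    if rng.random() < 0.5:
        return np.concatenate([[EOI], text_tokens, [EOT], image_tokens])
    else:
        return np.concatenate([[EOT], image_tokens, [EOI], text_tokens])

rng = np.random.default_rng(0)
text_tokens = np.array([5, 6, 7])          # tokenized caption (fixed length)
image_tokens = np.array([10, 11, 12, 13])  # VQ-VAE code indices
seq = make_sequence(text_tokens, image_tokens, rng)
```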
Inference:
- During inference, we cannot mix modality tokens, so during sampling we restrict the logits to the relevant modality:
- After <end of image>, only allow the model to sample text tokens (including <end of text>)
- After <end of text>, only allow the model to sample image tokens (including <end of image>)
- At the very start (conditioned on the <bos> token), only allow the model to sample one of <end of image> or <end of text>
- As the model may not always correctly sample the <end of image> token before the image ends, you may add a rule to force the model to always sample the correct number of image tokens (49 tokens).
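The inference rules above amount to masking the logits before each sampling step. Below is a sketch with a hypothetical token-id layout; the real id ranges depend on your vocabularies.

```python
import torch

TEXT_IDS = [2, 3, 4]   # hypothetical text-token ids
IMAGE_IDS = [5, 6, 7]  # hypothetical image-token ids
EOT, EOI = 0, 1        # <end of text>, <end of image>

def mask_logits(logits, allowed_ids):
    # disallow every token outside the current modality by setting
    # its logit to -inf before the softmax
    masked = torch.full_like(logits, float("-inf"))
    masked[..., allowed_ids] = logits[..., allowed_ids]
    return masked

logits = torch.randn(1, 8)  # model output over an 8-token vocabulary
# after <end of text>, only image tokens (and <end of image>) are legal
probs = torch.softmax(mask_logits(logits, IMAGE_IDS + [EOI]), dim=-1)
next_token = torch.multinomial(probs, 1).item()
```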
You can use the same hyperparameters as in 4(b) (but of course, feel free to tune your model to achieve better performance)
You will provide these deliverables:
- Over the course of training, record the average negative log-likelihood (nats / dim) of the training data (per minibatch) and test data (for your entire test set). Code is provided that automatically plots the training curves.
- Report the final test set performance of your final model
- 9 conditional samples based on provided text.
- 9 conditional samples based on provided images.
- 9 unconditional samples showcasing the model's capability in generating standalone text and images.
# multimodal model architecture version
class ScaledDotProductAttention(nn.Module):
def __init__(self, dropout=0.0):
super().__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
"""
q: (batch_size, n_heads, seq_len, head_size)
k: (batch_size, n_heads, seq_len, head_size)
v: (batch_size, n_heads, seq_len, head_size)
"""
d_k = q.shape[-1]
scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k) # (batch_size, n_heads, seq_len, seq_len)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = torch.softmax(scores, dim=-1) # (batch_size, n_heads, seq_len, seq_len)
attention_weights = self.dropout(attention_weights) # (batch_size, n_heads, seq_len, seq_len)
output = torch.matmul(attention_weights, v) # (batch_size, n_heads, seq_len, head_size)
return output
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.0, cache=False):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_size = d_model // n_heads
self.use_cache = cache
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.out = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.attention = ScaledDotProductAttention(dropout=dropout)
self.cached_k = None
self.cached_v = None
def split_heads(self, x):
"""
x: (batch_size, seq_len, d_model)
"""
batch_size, seq_len, d_model = x.shape
return x.view(batch_size, seq_len, self.n_heads, self.head_size).transpose(1, 2) # (batch_size, n_heads, seq_len, head_size)
def combine_heads(self, x):
"""
x: (batch_size, n_heads, seq_len, head_size)
"""
batch_size, n_heads, seq_len, head_size = x.shape
return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) # (batch_size, seq_len, d_model)
def forward(self, x, mask=None, use_cache=False, past_key_values=None):
batch_size, seq_len, d_model = x.shape
if past_key_values is not None:
self.cached_k, self.cached_v = past_key_values
q = self.W_q(x) # (batch_size, seq_len, d_model)
k = self.W_k(x)
v = self.W_v(x)
q = self.split_heads(q) # (batch_size, n_heads, seq_len, head_size)
k = self.split_heads(k)
v = self.split_heads(v)
# Use KV cache if enabled
if use_cache and self.cached_k is not None and self.cached_v is not None:
# Concatenate current k, v with cached k, v
k = torch.cat([self.cached_k, k], dim=2)
v = torch.cat([self.cached_v, v], dim=2)
self.cached_k = k
self.cached_v = v
        # Build a causal mask if none was given
        if mask is None:
            full_seq_len = k.size(2)
            # Each of the seq_len new queries may attend to every cached
            # position plus its own causal prefix among the new positions.
            # With no cache (full_seq_len == seq_len) this reduces to the
            # standard lower-triangular causal mask.
            mask = torch.tril(
                torch.ones(seq_len, full_seq_len, device=x.device),
                diagonal=full_seq_len - seq_len,
            )
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, full_seq_len)
# Use the attention module directly
output = self.attention(q, k, v, mask) # (batch_size, n_heads, seq_len, head_size)
# Combine heads
output = self.combine_heads(output) # (batch_size, seq_len, d_model)
past_key_values = (k, v)
if use_cache:
            return self.dropout(self.out(output)), past_key_values
else:
return self.dropout(self.out(output))
def clear_cache(self):
self.cached_k = None
self.cached_v = None
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1, use_cache=False):
super().__init__()
self.masked_mha = MultiHeadAttention(d_model, n_heads, dropout, cache=use_cache)
self.layer_norm1 = nn.LayerNorm(d_model)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
nn.Dropout(dropout)
)
self.layer_norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, use_cache=False, past_key_values=None):
# Self-attention with residual connection and layer normalization
residual = x
x = self.layer_norm1(x) # Pre-norm architecture
        if use_cache:
            # Always route through the cached path when caching is requested,
            # so the cache is populated even on the first step (past is None)
            x, past_key_values = self.masked_mha(x, use_cache=True, past_key_values=past_key_values)
        else:
            x = self.masked_mha(x)
x = residual + x # Residual connection
# Feed forward with residual connection and layer normalization
residual = x
x = self.layer_norm2(x) # Pre-norm architecture
x = self.feed_forward(x)
x = residual + x # Residual connection
if use_cache:
            return x, past_key_values
else:
return x
def clear_cache(self):
self.masked_mha.clear_cache()
class iGPT(nn.Module):
def __init__(self, vocab_size, context_length, d_model, n_heads, n_layers, dropout=0.1, use_cache=False):
super().__init__()
self.vocab_size = vocab_size
self.context_length = context_length
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.dropout = dropout
self.use_cache = use_cache
# Token embedding
self.token_embedding = nn.Embedding(vocab_size, d_model)
# Positional embedding (learned, as per iGPT specs)
self.position_embedding = nn.Embedding(context_length, d_model)
# Dropout
self.dropout = nn.Dropout(dropout)
# Stack of decoder layers
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, n_heads, dropout, use_cache=use_cache)
for _ in range(n_layers)
])
# Final layer norm
self.layer_norm = nn.LayerNorm(d_model)
# Output projection
self.output_projection = nn.Linear(d_model, vocab_size)
def forward(self, x, past_key_values=None, use_cache=False):
# x shape: (batch_size, seq_len)
batch_size, seq_len = x.shape
device = x.device
# Create position indices
positions = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
# Get embeddings
token_emb = self.token_embedding(x) # (batch_size, seq_len, d_model)
pos_emb = self.position_embedding(positions) # (batch_size, seq_len, d_model)
# Combine embeddings
x = token_emb + pos_emb # (batch_size, seq_len, d_model)
x = self.dropout(x)
        # Apply decoder layers, keeping one (k, v) cache per layer rather than
        # threading a single cache through the stack
        new_past_key_values = []
        for i, layer in enumerate(self.decoder_layers):
            if use_cache:
                layer_past = past_key_values[i] if past_key_values is not None else None
                x, layer_cache = layer(x, use_cache=True, past_key_values=layer_past)
                new_past_key_values.append(layer_cache)
            else:
                x = layer(x)
        # Apply final layer norm
        x = self.layer_norm(x)  # (batch_size, seq_len, d_model)
        # Project to vocabulary
        logits = self.output_projection(x)  # (batch_size, seq_len, vocab_size)
        if use_cache:
            return logits, new_past_key_values
        else:
            return logits
def clear_cache(self):
for layer in self.decoder_layers:
layer.clear_cache()
import math
def evaluate_model(model, data_loader, sequence_length, vocab_size, device):
"""
Evaluates model performance on a dataset.
Args:
model: The iGPT model
data_loader: DataLoader containing tokenized sequences (already includes BOS token)
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
device: Device to run evaluation on
Returns:
Average loss (negative log-likelihood) per dimension
"""
model.eval()
total_loss = 0
total_samples = 0
with torch.no_grad():
for data in data_loader:
data = data.to(device) # Shape: (batch_size, sequence_length)
batch_size = data.size(0)
# Data already includes BOS token at the beginning
# Create input sequence (all tokens except the last one)
input_seq = data[:, :-1] # Shape: (batch_size, sequence_length-1)
# Create targets (all tokens except the first BOS token)
targets = data[:, 1:] # Shape: (batch_size, sequence_length-1)
# Forward pass
logits = model(input_seq) # Shape: (batch_size, sequence_length-1, vocab_size)
# Compute loss
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1), reduction='sum')
total_loss += loss.item()
total_samples += batch_size * (sequence_length - 1)
return total_loss / total_samples
def train_igpt(model, train_loader, test_loader, sequence_length, vocab_size,
device, num_epochs, learning_rate):
"""
Trains the iGPT model.
Args:
model: The iGPT model to train
train_loader: DataLoader for training data (already includes BOS token)
test_loader: DataLoader for test data (already includes BOS token)
sequence_length: Length of token sequences including <bos>
vocab_size: Size of vocabulary
device: Device to train on
num_epochs: Number of training epochs
learning_rate: Initial learning rate
Returns:
train_losses: Array of training losses per minibatch
test_losses: Array of test losses per epoch
"""
# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Learning rate scheduler with warmup and cosine decay
warmup_steps = 1000
total_steps = len(train_loader) * num_epochs
def lr_lambda(step):
if step < warmup_steps:
return step / warmup_steps
else:
decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Initialize arrays to store losses
train_losses = []
test_losses = [evaluate_model(model, test_loader, sequence_length, vocab_size, device)]
# Training loop
for epoch in range(num_epochs):
model.train()
epoch_losses = []
for (batch_idx,data) in enumerate(train_loader):
data = data.to(device) # Shape: (batch_size, sequence_length)
batch_size = data.size(0)
# Data already includes BOS token at the beginning
# Create input sequence (all tokens except the last one)
input_seq = data[:, :-1] # Shape: (batch_size, sequence_length-1)
# Create targets (all tokens except the first BOS token)
targets = data[:, 1:] # Shape: (batch_size, sequence_length-1)
# Forward pass
logits = model(input_seq) # Shape: (batch_size, sequence_length-1, vocab_size)
# Compute loss
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
# Record loss
train_losses.append(loss.item())
epoch_losses.append(loss.item())
if batch_idx % 50 == 0:
print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
# Evaluate on test set after each epoch
test_loss = evaluate_model(model, test_loader, sequence_length, vocab_size, device)
test_losses.append(test_loss)
print(f"Epoch {epoch+1}/{num_epochs} completed. Test Loss: {test_loss:.4f}")
return np.array(train_losses), np.array(test_losses)
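The learning-rate schedule inside train_igpt can be checked in isolation. The step counts below are illustrative (warmup_steps=1000 matches the code; total_steps roughly matches 469 batches x 30 epochs from this run):

```python
import math

# Standalone sketch of the warmup + cosine-decay multiplier from train_igpt.
warmup_steps, total_steps = 1000, 14070

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps              # linear warmup from 0 to 1
    decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # cosine decay to 0
```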
class Tokenizer:
def __init__(self, texts, offset):
self.texts = texts
self.offset = offset
self.all_words = set()
for text in texts:
self.all_words.update(text.split())
        # Sort for a deterministic ordering across runs (a set's order is not)
        self.all_words = sorted(self.all_words)
        self.vocab_size = len(self.all_words)
        # Id layout: image codes occupy [0, offset), words occupy
        # [offset, offset + vocab_size), and the two end-of-modality tokens
        # come directly after, so every id is distinct.
        # BOS shares id 0 with image code 0; this is tolerable because BOS
        # only ever appears at position 0 and is never a prediction target.
        self.bos_token = 0
        self.end_of_text_token = self.vocab_size + self.offset
        self.end_of_image_token = self.vocab_size + 1 + self.offset
        self.all_words.extend(['<end_of_text>', '<end_of_image>'])
        # Map words to ids starting at offset, so '<end_of_text>' and
        # '<end_of_image>' land exactly on the ids declared above
        self.word_to_id = {word: i + self.offset for i, word in enumerate(self.all_words)}
        self.id_to_word = {i + self.offset: word for i, word in enumerate(self.all_words)}
        # Add BOS token to the decoding map
        self.id_to_word[self.bos_token] = '<bos>'
def text_encode(self, text):
tokens = [self.word_to_id[word] for word in text.split()]
return torch.tensor(tokens)
def text_decode(self, tokens):
return ' '.join([self.id_to_word[token] for token in tokens if token != self.end_of_text_token and token != self.bos_token])
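One consistent way to lay out the three id ranges, sketched with a hypothetical 16-code VQVAE codebook and a 3-word vocabulary (in q6_a, offset is the real codebook size vqvae.n_embeddings):

```python
# Sketch of a disjoint token-id layout: image codes occupy [0, 16), words
# sit directly above them, and the two end-of-modality tokens come last.
# All sizes and words here are hypothetical.
n_embeddings = 16
words = ["red", "blue", "square"]
offset = n_embeddings

word_to_id = {w: i + offset for i, w in enumerate(words)}
end_of_text_token = len(words) + offset       # first id above the words
end_of_image_token = len(words) + 1 + offset  # highest id in the vocabulary
total_vocab_size = n_embeddings + len(words) + 2
```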
def create_dataset(images, texts, vqvae, text_tokenizer, batch_size):
# create a dataset of images and texts
dataset = []
bos_token = text_tokenizer.bos_token
end_of_image_token = text_tokenizer.end_of_image_token
end_of_text_token = text_tokenizer.end_of_text_token
print(f"Creating dataset from {len(images)} samples...")
# Pre-tokenize all text data at once for efficiency
print("Pre-tokenizing all text data...")
all_text_tokens = [text_tokenizer.text_encode(text) for text in texts]
# Batch process images for VQVAE quantization
print("Batch processing images...")
batch_size_process = 128
all_image_tokens = []
for i in range(0, len(images), batch_size_process):
batch_end = min(i + batch_size_process, len(images))
batch_images = images[i:batch_end]
# Process batch of images
batch_image_tokens = vqvae.quantize(batch_images)
# Flatten each image's tokens and store
for j in range(batch_image_tokens.shape[0]):
image_tokens_flat = batch_image_tokens[j].flatten()
all_image_tokens.append(image_tokens_flat)
        if (i // batch_size_process) % 50 == 0:
            done = min(i + batch_size_process, len(images))
            print(f"Processed {done}/{len(images)} images ({done / len(images) * 100:.1f}%)")
# Create special token tensors once
bos_tensor = torch.tensor([bos_token])
end_of_image_tensor = torch.tensor([end_of_image_token])
end_of_text_tensor = torch.tensor([end_of_text_token])
print("Assembling dataset...")
for idx in range(len(texts)):
text_tokens = all_text_tokens[idx]
image_tokens_flat = all_image_tokens[idx]
if idx % 2 == 0:
# text followed by image
complete_tokens = torch.cat((bos_tensor, end_of_image_tensor, text_tokens, end_of_text_tensor, image_tokens_flat))
dataset.append(complete_tokens)
else:
# image followed by text
complete_tokens = torch.cat((bos_tensor, end_of_text_tensor, image_tokens_flat, end_of_image_tensor, text_tokens))
dataset.append(complete_tokens)
print(f"Dataset creation complete! Total samples: {len(dataset)}")
print(f"Creating DataLoader with batch_size={batch_size}")
# create dataloader
return DataLoader(dataset, batch_size=batch_size, shuffle=True)
def generate_conditional_samples_from_text(model, text_tokenizer, vqvae, text_prompts, device, max_length=58):
"""
Generate images conditioned on text prompts.
Args:
model: Trained iGPT model
text_tokenizer: Text tokenizer
vqvae: VQVAE model for decoding image tokens
text_prompts: List of text strings to condition on
device: Device to run on
max_length: Maximum sequence length
Returns:
List of (image, text) tuples
"""
model.eval()
samples = []
with torch.no_grad():
for text_prompt in text_prompts:
# Start with BOS token and end_of_image token, then text tokens, then end_of_text token
text_tokens = text_tokenizer.text_encode(text_prompt)
input_seq = torch.cat([
torch.tensor([text_tokenizer.bos_token]),
torch.tensor([text_tokenizer.end_of_image_token]),
text_tokens,
torch.tensor([text_tokenizer.end_of_text_token])
]).unsqueeze(0).to(device)
# Generate 49 image tokens
for _ in range(49): # 7x7 = 49 image tokens
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Restrict to image tokens only (0 to vqvae.n_embeddings-1)
mask = torch.zeros_like(next_token_logits)
mask[:vqvae.n_embeddings] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
# Sample next token
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, 1)
# Append to sequence
input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
# Extract image tokens and decode
image_tokens = input_seq[0, -49:].cpu().numpy().reshape(7, 7)
decoded_image = vqvae.decode(image_tokens.reshape(1, 7, 7))[0]
samples.append((decoded_image, text_prompt))
return samples
def generate_conditional_samples_from_image(model, text_tokenizer, vqvae, image_prompts, device, max_length=58):
"""
Generate text conditioned on image prompts.
Args:
model: Trained iGPT model
text_tokenizer: Text tokenizer
vqvae: VQVAE model for encoding image tokens
image_prompts: Array of images to condition on
device: Device to run on
max_length: Maximum sequence length
Returns:
List of (image, text) tuples
"""
model.eval()
samples = []
with torch.no_grad():
for image_prompt in image_prompts:
# Quantize the image
image_tokens = vqvae.quantize(image_prompt.reshape(1, *image_prompt.shape))[0].flatten()
# Start with BOS token, end_of_text token, image tokens, then end_of_image token
input_seq = torch.cat([
torch.tensor([text_tokenizer.bos_token]),
torch.tensor([text_tokenizer.end_of_text_token]),
torch.tensor(image_tokens),
torch.tensor([text_tokenizer.end_of_image_token])
]).unsqueeze(0).to(device)
# Generate text tokens (typically 6 words based on the dataset)
generated_text_tokens = []
for _ in range(6): # Assuming 6 words per text description
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Restrict to text tokens only (excluding special tokens)
mask = torch.zeros_like(next_token_logits)
# Text tokens start from vqvae.n_embeddings + 1 (excluding BOS which is 0)
for word, token_id in text_tokenizer.word_to_id.items():
if word not in ['<end_of_text>', '<end_of_image>']:
mask[token_id] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
# Sample next token
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, 1)
generated_text_tokens.append(next_token.item())
# Append to sequence
input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
# Decode text
generated_text = text_tokenizer.text_decode(generated_text_tokens)
samples.append((image_prompt, generated_text))
return samples
def generate_unconditional_samples(model, text_tokenizer, vqvae, device, num_samples=9, max_length=58):
"""
Generate unconditional samples (both text and images).
Args:
model: Trained iGPT model
text_tokenizer: Text tokenizer
vqvae: VQVAE model for decoding
device: Device to run on
num_samples: Number of samples to generate
max_length: Maximum sequence length
Returns:
List of (image, text) tuples
"""
model.eval()
samples = []
with torch.no_grad():
for _ in range(num_samples):
# Start with BOS token
input_seq = torch.tensor([text_tokenizer.bos_token]).unsqueeze(0).to(device)
# First, decide which modality to start with
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Only allow end_of_image or end_of_text tokens
mask = torch.zeros_like(next_token_logits)
mask[text_tokenizer.end_of_image_token] = 1
mask[text_tokenizer.end_of_text_token] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
probs = torch.softmax(next_token_logits, dim=-1)
modality_token = torch.multinomial(probs, 1)
input_seq = torch.cat([input_seq, modality_token.unsqueeze(0)], dim=1)
if modality_token.item() == text_tokenizer.end_of_image_token:
# Generate text first, then image
# Generate 6 text tokens
for _ in range(6):
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Restrict to text tokens
mask = torch.zeros_like(next_token_logits)
for word, token_id in text_tokenizer.word_to_id.items():
if word not in ['<end_of_text>', '<end_of_image>']:
mask[token_id] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, 1)
input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
# Add end_of_text token
input_seq = torch.cat([input_seq, torch.tensor([text_tokenizer.end_of_text_token]).unsqueeze(0).to(device)], dim=1)
# Generate 49 image tokens
for _ in range(49):
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Restrict to image tokens
mask = torch.zeros_like(next_token_logits)
mask[:vqvae.n_embeddings] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, 1)
input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
# Extract text and image
text_tokens = input_seq[0, 2:8].cpu().numpy() # Skip BOS, end_of_image, get 6 text tokens
image_tokens = input_seq[0, -49:].cpu().numpy().reshape(7, 7)
else: # end_of_text_token
# Generate image first, then text
# Generate 49 image tokens
for _ in range(49):
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Restrict to image tokens
mask = torch.zeros_like(next_token_logits)
mask[:vqvae.n_embeddings] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, 1)
input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
# Add end_of_image token
input_seq = torch.cat([input_seq, torch.tensor([text_tokenizer.end_of_image_token]).unsqueeze(0).to(device)], dim=1)
# Generate 6 text tokens
for _ in range(6):
logits = model(input_seq)
next_token_logits = logits[0, -1, :]
# Restrict to text tokens
mask = torch.zeros_like(next_token_logits)
for word, token_id in text_tokenizer.word_to_id.items():
if word not in ['<end_of_text>', '<end_of_image>']:
mask[token_id] = 1
next_token_logits = next_token_logits * mask + (1 - mask) * (-1e9)
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, 1)
input_seq = torch.cat([input_seq, next_token.unsqueeze(0)], dim=1)
# Extract image and text
image_tokens = input_seq[0, 2:51].cpu().numpy().reshape(7, 7) # Skip BOS, end_of_text, get 49 image tokens
text_tokens = input_seq[0, -6:].cpu().numpy() # Get last 6 text tokens
# Decode
decoded_image = vqvae.decode(image_tokens.reshape(1, 7, 7))[0]
decoded_text = text_tokenizer.text_decode(text_tokens)
samples.append((decoded_image, decoded_text))
return samples
def q6_a(train_data, test_data, image_shape, train_text, test_text, image_test_prompt, text_test_prompt, vqvae):
"""
train_data: A (n_train, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
test_data: A (n_test, H, W, C) uint8 numpy array of color images with values in {0, 1, 2, 3}
image_shape: tuple (H, W, C) The shape of the images in the dataset, indicating height, width, and number of color channels.
train_text: list[str] Text data associated with each training image.
test_text: list[str] Text data associated with each test image.
image_test_prompt: (9, H, W, C) Image data used for generating conditional text samples during testing.
text_test_prompt: list of 9 strings Text prompts used for generating conditional image samples during testing.
vqvae: a vqvae model, trained on the relevant dataset
Returns
- a (# of training iterations,) numpy array of train_losses evaluated every minibatch
- a (# of epochs + 1,) numpy array of test_losses evaluated once at initialization and after each epoch
    - a list of 9 (image, text), corresponding to the image-conditioned samples
    - a list of 9 (image, text), corresponding to the text-conditioned samples
- a list of 9 (image, text), corresponding to unconditional samples
"""
    # Offset the text-token ids by the VQVAE codebook size so image and text ids do not collide
text_tokenizer = Tokenizer(train_text, vqvae.n_embeddings)
H, W, C = image_shape
batch_size = 128
learning_rate = 1e-3
num_epochs = 30
d_model = 128
n_heads = 4
n_layers = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# determine sequence length and vocab size
    sequence_length = 58  # 49 image tokens + 6 text tokens + 2 modality tokens + 1 <bos>
# Total vocab size should include both image tokens and text tokens
total_vocab_size = vqvae.n_embeddings + len(text_tokenizer.all_words)
    # Build train and test dataloaders from the images and captions
train_loader = create_dataset(train_data, train_text, vqvae, text_tokenizer, batch_size)
test_loader = create_dataset(test_data, test_text, vqvae, text_tokenizer, batch_size)
model = iGPT(total_vocab_size, sequence_length, d_model, n_heads, n_layers).to(device)
train_losses, test_losses = train_igpt(model, train_loader, test_loader,
sequence_length, total_vocab_size, device,
num_epochs, learning_rate)
# Generate samples
samples_text_conditioned = generate_conditional_samples_from_text(
model, text_tokenizer, vqvae, text_test_prompt, device
)
samples_image_conditioned = generate_conditional_samples_from_image(
model, text_tokenizer, vqvae, image_test_prompt, device
)
samples_unconditioned = generate_unconditional_samples(
model, text_tokenizer, vqvae, device, num_samples=9
)
return train_losses, test_losses, samples_image_conditioned, samples_text_conditioned, samples_unconditioned
Results¶
Once you've implemented q6_a, execute the cells below to visualize and save your results
q6a_save_results(q6_a)
data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data data_dir: /home/nghiaph/workspace/deepul/homeworks/hw1/data Creating dataset from 60000 samples... Pre-tokenizing all text data... Batch processing images... Processed 128/60000 images (2.1%) Assembling dataset... Dataset creation complete! Total samples: 60000 Creating DataLoader with batch_size=128 Creating dataset from 10000 samples... Pre-tokenizing all text data... Batch processing images... Processed 128/10000 images (12.8%) Assembling dataset... Dataset creation complete! Total samples: 10000 Creating DataLoader with batch_size=128 Epoch 1/30, Batch 0/469, Loss: 7.0951 Epoch 1/30, Batch 50/469, Loss: 6.9365 Epoch 1/30, Batch 100/469, Loss: 6.4238 Epoch 1/30, Batch 150/469, Loss: 5.9999 Epoch 1/30, Batch 200/469, Loss: 5.5952 Epoch 1/30, Batch 250/469, Loss: 5.0091 Epoch 1/30, Batch 300/469, Loss: 4.3433 Epoch 1/30, Batch 350/469, Loss: 3.9322 Epoch 1/30, Batch 400/469, Loss: 3.8381 Epoch 1/30, Batch 450/469, Loss: 3.7479 Epoch 1/30 completed. Test Loss: 3.6811 Epoch 2/30, Batch 0/469, Loss: 3.7475 Epoch 2/30, Batch 50/469, Loss: 3.6743 Epoch 2/30, Batch 100/469, Loss: 3.6187 Epoch 2/30, Batch 150/469, Loss: 3.6176 Epoch 2/30, Batch 200/469, Loss: 3.5434 Epoch 2/30, Batch 250/469, Loss: 3.6104 Epoch 2/30, Batch 300/469, Loss: 3.5620 Epoch 2/30, Batch 350/469, Loss: 3.4631 Epoch 2/30, Batch 400/469, Loss: 3.4442 Epoch 2/30, Batch 450/469, Loss: 3.4916 Epoch 2/30 completed. Test Loss: 3.3927 Epoch 3/30, Batch 0/469, Loss: 3.4020 Epoch 3/30, Batch 50/469, Loss: 3.4850 Epoch 3/30, Batch 100/469, Loss: 3.3626 Epoch 3/30, Batch 150/469, Loss: 3.3485 Epoch 3/30, Batch 200/469, Loss: 3.3734 Epoch 3/30, Batch 250/469, Loss: 3.2823 Epoch 3/30, Batch 300/469, Loss: 3.2191 Epoch 3/30, Batch 350/469, Loss: 3.2746 Epoch 3/30, Batch 400/469, Loss: 3.3012 Epoch 3/30, Batch 450/469, Loss: 3.2492 Epoch 3/30 completed. 
Test Loss: 3.1725 Epoch 4/30, Batch 0/469, Loss: 3.1248 Epoch 4/30, Batch 50/469, Loss: 3.2201 Epoch 4/30, Batch 100/469, Loss: 3.2481 Epoch 4/30, Batch 150/469, Loss: 3.2488 Epoch 4/30, Batch 200/469, Loss: 3.1117 Epoch 4/30, Batch 250/469, Loss: 3.0753 Epoch 4/30, Batch 300/469, Loss: 3.1594 Epoch 4/30, Batch 350/469, Loss: 3.1258 Epoch 4/30, Batch 400/469, Loss: 3.0861 Epoch 4/30, Batch 450/469, Loss: 3.0951 Epoch 4/30 completed. Test Loss: 3.0241 Epoch 5/30, Batch 0/469, Loss: 3.0616 Epoch 5/30, Batch 50/469, Loss: 3.0368 Epoch 5/30, Batch 100/469, Loss: 3.0523 Epoch 5/30, Batch 150/469, Loss: 3.0173 Epoch 5/30, Batch 200/469, Loss: 2.9779 Epoch 5/30, Batch 250/469, Loss: 3.0774 Epoch 5/30, Batch 300/469, Loss: 3.0289 Epoch 5/30, Batch 350/469, Loss: 3.0879 Epoch 5/30, Batch 400/469, Loss: 2.9733 Epoch 5/30, Batch 450/469, Loss: 3.0253 Epoch 5/30 completed. Test Loss: 2.9287 Epoch 6/30, Batch 0/469, Loss: 2.9740 Epoch 6/30, Batch 50/469, Loss: 2.9356 Epoch 6/30, Batch 100/469, Loss: 3.0952 Epoch 6/30, Batch 150/469, Loss: 2.9799 Epoch 6/30, Batch 200/469, Loss: 2.9494 Epoch 6/30, Batch 250/469, Loss: 2.9638 Epoch 6/30, Batch 300/469, Loss: 2.9206 Epoch 6/30, Batch 350/469, Loss: 2.9036 Epoch 6/30, Batch 400/469, Loss: 3.0060 Epoch 6/30, Batch 450/469, Loss: 2.9352 Epoch 6/30 completed. Test Loss: 2.8623 Epoch 7/30, Batch 0/469, Loss: 2.9634 Epoch 7/30, Batch 50/469, Loss: 2.9687 Epoch 7/30, Batch 100/469, Loss: 2.8994 Epoch 7/30, Batch 150/469, Loss: 2.9184 Epoch 7/30, Batch 200/469, Loss: 2.8395 Epoch 7/30, Batch 250/469, Loss: 2.9313 Epoch 7/30, Batch 300/469, Loss: 2.8314 Epoch 7/30, Batch 350/469, Loss: 2.9302 Epoch 7/30, Batch 400/469, Loss: 2.8441 Epoch 7/30, Batch 450/469, Loss: 2.8525 Epoch 7/30 completed. 
Test Loss: 2.8188 Epoch 8/30, Batch 0/469, Loss: 2.8840 Epoch 8/30, Batch 50/469, Loss: 2.8966 Epoch 8/30, Batch 100/469, Loss: 2.8827 Epoch 8/30, Batch 150/469, Loss: 2.8918 Epoch 8/30, Batch 200/469, Loss: 2.8461 Epoch 8/30, Batch 250/469, Loss: 2.9240 Epoch 8/30, Batch 300/469, Loss: 2.8935 Epoch 8/30, Batch 350/469, Loss: 2.8889 Epoch 8/30, Batch 400/469, Loss: 2.7807 Epoch 8/30, Batch 450/469, Loss: 2.8786 Epoch 8/30 completed. Test Loss: 2.7761 Epoch 9/30, Batch 0/469, Loss: 2.8313 Epoch 9/30, Batch 50/469, Loss: 2.8219 Epoch 9/30, Batch 100/469, Loss: 2.8402 Epoch 9/30, Batch 150/469, Loss: 2.8810 Epoch 9/30, Batch 200/469, Loss: 2.7957 Epoch 9/30, Batch 250/469, Loss: 2.8398 Epoch 9/30, Batch 300/469, Loss: 2.9297 Epoch 9/30, Batch 350/469, Loss: 2.9504 Epoch 9/30, Batch 400/469, Loss: 2.8850 Epoch 9/30, Batch 450/469, Loss: 2.8889 Epoch 9/30 completed. Test Loss: 2.7530 Epoch 10/30, Batch 0/469, Loss: 2.7714 Epoch 10/30, Batch 50/469, Loss: 2.8695 Epoch 10/30, Batch 100/469, Loss: 2.7365 Epoch 10/30, Batch 150/469, Loss: 2.8018 Epoch 10/30, Batch 200/469, Loss: 2.7453 Epoch 10/30, Batch 250/469, Loss: 2.8007 Epoch 10/30, Batch 300/469, Loss: 2.8421 Epoch 10/30, Batch 350/469, Loss: 2.8254 Epoch 10/30, Batch 400/469, Loss: 2.7566 Epoch 10/30, Batch 450/469, Loss: 2.8254 Epoch 10/30 completed. Test Loss: 2.7228 Epoch 11/30, Batch 0/469, Loss: 2.7551 Epoch 11/30, Batch 50/469, Loss: 2.8213 Epoch 11/30, Batch 100/469, Loss: 2.8193 Epoch 11/30, Batch 150/469, Loss: 2.7120 Epoch 11/30, Batch 200/469, Loss: 2.7770 Epoch 11/30, Batch 250/469, Loss: 2.7634 Epoch 11/30, Batch 300/469, Loss: 2.7723 Epoch 11/30, Batch 350/469, Loss: 2.7986 Epoch 11/30, Batch 400/469, Loss: 2.6951 Epoch 11/30, Batch 450/469, Loss: 2.8092 Epoch 11/30 completed. 
Test Loss: 2.7056
Epoch 12/30 completed. Test Loss: 2.6857
Epoch 13/30 completed. Test Loss: 2.6783
Epoch 14/30 completed. Test Loss: 2.6631
Epoch 15/30 completed. Test Loss: 2.6495
Epoch 16/30 completed. Test Loss: 2.6388
Epoch 17/30 completed. Test Loss: 2.6326
Epoch 18/30 completed. Test Loss: 2.6201
Epoch 19/30 completed. Test Loss: 2.6183
Epoch 20/30 completed. Test Loss: 2.6076
Epoch 21/30 completed. Test Loss: 2.6016
Epoch 22/30 completed. Test Loss: 2.5993
Epoch 23/30 completed. Test Loss: 2.5942
Epoch 24/30 completed. Test Loss: 2.5928
Epoch 25/30 completed. Test Loss: 2.5892
Epoch 26/30 completed. Test Loss: 2.5870
Epoch 27/30 completed. Test Loss: 2.5856
Epoch 28/30 completed. Test Loss: 2.5855
Epoch 29/30 completed. Test Loss: 2.5852
Epoch 30/30 completed. Test Loss: 2.5852
/tmp/ipykernel_2898721/2746847972.py:82: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor). torch.tensor(image_tokens),
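The warning above comes from calling `torch.tensor()` on something that is already a tensor. A minimal sketch of the recommended fix, using a hypothetical `image_tokens` tensor standing in for the one in the training cell:

```python
import torch

# Hypothetical stand-in for the image_tokens tensor from the training cell.
image_tokens = torch.randint(0, 256, (4, 49))

# torch.tensor(image_tokens) would emit the UserWarning above.
# detach().clone() makes an explicit copy of an existing tensor instead.
tokens_copy = image_tokens.detach().clone()

print(tokens_copy.shape)  # same shape, no warning
```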
Final Test Loss: 2.5852
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-10..262].
(The warning above repeats once per sampled image, each with a slightly different out-of-range interval.)
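These warnings mean the sampled pixel values fall outside the valid display range, so matplotlib clips them on your behalf. A minimal sketch of clipping the samples yourself before plotting, using made-up sample values for illustration:

```python
import numpy as np

# Sampled pixel values can fall outside [0, 255]; the range below mirrors
# one reported in the warning ([-10..262]).
sample = np.array([[-10.0, 130.0], [262.0, 80.0]])

# Clip to the valid integer range and cast, so imshow gets clean input.
img = np.clip(sample, 0, 255).astype(np.uint8)

print(img.min(), img.max())  # values now stay within [0, 255]
```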